sahara/savanna/db/sqlalchemy/api.py

722 lines
21 KiB
Python

# Copyright (c) 2013 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of SQLAlchemy backend."""
import sys
import sqlalchemy as sa
from savanna.db.sqlalchemy import models as m
from savanna import exceptions as ex
from savanna.openstack.common.db import exception as db_exc
from savanna.openstack.common.db.sqlalchemy import session as db_session
from savanna.openstack.common import log as logging
# Module-level logger; used below to report database setup/teardown errors.
LOG = logging.getLogger(__name__)
# Re-export the common session helpers under this module's namespace so the
# functions below (and callers of this backend) can use them directly.
get_engine = db_session.get_engine
get_session = db_session.get_session
def get_backend():
    """Return this module object, which acts as the DB backend."""
    this_module = sys.modules[__name__]
    return this_module
def model_query(model, context, session=None, project_only=True):
    """Build a query for ``model``, optionally scoped to the caller's tenant.

    :param model: base model (or query target) passed to ``session.query``
    :param context: request context; ``is_admin`` and ``tenant_id`` are read
    :param session: session to query with; a new one is obtained from
                    ``get_session()`` when this is falsy
    :param project_only: when truthy and the context is not an admin one,
                         restrict the query to rows whose ``tenant_id``
                         matches ``context.tenant_id``
    :returns: the (possibly tenant-filtered) query object
    """
    if not session:
        session = get_session()

    base_query = session.query(model)

    # Admin contexts and callers that opt out see rows from every tenant.
    restrict_to_tenant = project_only and not context.is_admin
    if restrict_to_tenant:
        return base_query.filter_by(tenant_id=context.tenant_id)
    return base_query
def count_query(model, context, session=None, project_only=None):
    """Build a query that counts rows of ``model`` by its ``id`` column.

    :param model: base model whose ``id`` column is counted
    :param context: context to query under
    :param session: optional session, forwarded to ``model_query``
    :param project_only: forwarded to ``model_query``; if truthy and the
                         context is user-type, the count is restricted to
                         the context's tenant

    NOTE(review): ``project_only`` defaults to None here (unlike
    ``model_query``, which defaults to True), so no tenant restriction is
    applied unless the caller passes a truthy value.
    """
    row_count = sa.func.count(model.id)
    return model_query(row_count, context, session, project_only)
def setup_db():
    """Create all tables registered on the models' metadata.

    :returns: True on success, False if an OperationalError was raised
              (the error is logged).
    """
    try:
        db_engine = db_session.get_engine(sqlite_fk=True)
        m.Cluster.metadata.create_all(db_engine)
    except sa.exc.OperationalError as err:
        LOG.error("Database registration exception: %s", err)
        return False
    else:
        return True
def drop_db():
    """Drop all tables registered on the models' metadata.

    :returns: True on success, False if any exception was raised (the
              error is logged rather than propagated).
    """
    try:
        db_engine = db_session.get_engine(sqlite_fk=True)
        m.Cluster.metadata.drop_all(db_engine)
    except Exception as err:
        LOG.error("Database shutdown exception: %s", err)
        return False
    else:
        return True
## Helpers for building constraints / equality checks
def constraint(**conditions):
    """Wrap keyword conditions (field name -> condition) in a Constraint."""
    built = Constraint(conditions)
    return built
def equal_any(*values):
    """Build an EqualityCondition over the given candidate values."""
    condition = EqualityCondition(values)
    return condition
def not_equal(*values):
    """Build an InequalityCondition over the given excluded values."""
    condition = InequalityCondition(values)
    return condition
class Constraint(object):
    """A set of per-field conditions that can be applied to a query."""

    def __init__(self, conditions):
        """Store the conditions.

        :param conditions: dict mapping a model attribute name to a
                           condition object exposing ``clauses(field)``
        """
        self.conditions = conditions

    def apply(self, model, query):
        """Filter ``query`` by every clause of every stored condition.

        :param model: model whose attributes are looked up by field name
        :param query: query object exposing ``filter(clause)``
        :returns: the query after all clauses have been applied
        """
        # ``items()`` rather than the Python 2-only ``iteritems()`` so the
        # class works on both Python 2 and Python 3.
        for key, condition in self.conditions.items():
            for clause in condition.clauses(getattr(model, key)):
                query = query.filter(clause)
        return query
class EqualityCondition(object):
    """Condition matching a field that equals any one of several values."""

    def __init__(self, values):
        """Store the candidate values.

        :param values: iterable of values the field may be equal to
        """
        self.values = values

    def clauses(self, field):
        """Return the clauses expressing ``field == v1 OR field == v2 ...``.

        :param field: column/attribute to compare
        :returns: a one-element list holding the OR-ed comparison, so that
                  ``Constraint.apply`` can iterate over the result the same
                  way it does for ``InequalityCondition.clauses``
        """
        # ``or_`` takes the comparisons as separate positional arguments,
        # so the list must be unpacked rather than passed as one argument.
        return [sa.or_(*[field == value for value in self.values])]
class InequalityCondition(object):
    """Condition matching a field that differs from every given value."""

    def __init__(self, values):
        """Store the values the field must not be equal to."""
        self.values = values

    def clauses(self, field):
        """Return one ``field != value`` clause per stored value."""
        result = []
        for excluded in self.values:
            result.append(field != excluded)
        return result
## Cluster ops
def _cluster_get(context, session, cluster_id):
    """Fetch the first Cluster with the given id using ``session``.

    Callers in this module treat a falsy/None result as "not found".
    """
    matching = model_query(m.Cluster, context, session).filter_by(
        id=cluster_id)
    return matching.first()
def cluster_get(context, cluster_id):
    """Look up a single Cluster by id on a newly obtained session."""
    session = get_session()
    return _cluster_get(context, session, cluster_id)
def cluster_get_all(context, **kwargs):
    """Return all Clusters whose columns match the keyword filters."""
    filtered = model_query(m.Cluster, context).filter_by(**kwargs)
    return filtered.all()
def cluster_create(context, values):
    """Create a Cluster together with its node groups in one transaction.

    :param context: request context, used to re-read the created cluster
    :param values: dict of Cluster fields; an optional "node_groups" key
                   holds a list of dicts, one per NodeGroup to create.
                   The dict is copied, so the caller's copy is untouched.
    :returns: the cluster as re-read through ``cluster_get``
    :raises ex.DBDuplicateEntry: if saving the cluster or one of its node
                                 groups hits a duplicate entry
    """
    fields = values.copy()
    node_group_specs = fields.pop("node_groups", [])

    new_cluster = m.Cluster()
    new_cluster.update(fields)

    db = get_session()
    with db.begin():
        try:
            new_cluster.save(session=db)
        except db_exc.DBDuplicateEntry as dup:
            raise ex.DBDuplicateEntry("Duplicate entry for Cluster: %s"
                                      % dup.columns)

        try:
            for spec in node_group_specs:
                group = m.NodeGroup()
                # Link the group to its cluster first; the caller-supplied
                # fields are applied afterwards.
                group.update({"cluster_id": new_cluster.id})
                group.update(spec)
                group.save(session=db)
        except db_exc.DBDuplicateEntry as dup:
            raise ex.DBDuplicateEntry("Duplicate entry for NodeGroup: %s"
                                      % dup.columns)

    return cluster_get(context, new_cluster.id)
def cluster_update(context, cluster_id, values):
    """Apply ``values`` to an existing Cluster inside one transaction.

    :returns: the updated cluster object
    :raises ex.NotFoundException: if no cluster with that id is found
    """
    db = get_session()
    with db.begin():
        target = _cluster_get(context, db, cluster_id)
        if target is not None:
            target.update(values)
        else:
            raise ex.NotFoundException(cluster_id,
                                       "Cluster id '%s' not found!")

    return target
def cluster_destroy(context, cluster_id):
    """Delete the Cluster with the given id inside one transaction.

    :raises ex.NotFoundException: if no cluster with that id is found
    """
    db = get_session()
    with db.begin():
        doomed = _cluster_get(context, db, cluster_id)
        if doomed:
            db.delete(doomed)
        else:
            raise ex.NotFoundException(cluster_id,
                                       "Cluster id '%s' not found!")
## Node Group ops
def _node_group_get(context, session, node_group_id):
    """Fetch the first NodeGroup with the given id using ``session``."""
    matching = model_query(m.NodeGroup, context, session).filter_by(
        id=node_group_id)
    return matching.first()
def node_group_add(context, cluster_id, values):
    """Add a new NodeGroup to an existing cluster.

    :param values: dict of NodeGroup fields applied after ``cluster_id``
    :returns: id of the newly added node group
    :raises ex.NotFoundException: if the cluster is not found
    """
    db = get_session()
    with db.begin():
        if not _cluster_get(context, db, cluster_id):
            raise ex.NotFoundException(cluster_id,
                                       "Cluster id '%s' not found!")

        new_group = m.NodeGroup()
        new_group.update({"cluster_id": cluster_id})
        new_group.update(values)
        db.add(new_group)

    return new_group.id
def node_group_update(context, node_group_id, values):
    """Apply ``values`` to an existing NodeGroup inside one transaction.

    :raises ex.NotFoundException: if the node group is not found
    """
    db = get_session()
    with db.begin():
        target = _node_group_get(context, db, node_group_id)
        if target:
            target.update(values)
        else:
            raise ex.NotFoundException(node_group_id,
                                       "Node Group id '%s' not found!")
def node_group_remove(context, node_group_id):
    """Delete the NodeGroup with the given id inside one transaction.

    :raises ex.NotFoundException: if the node group is not found
    """
    db = get_session()
    with db.begin():
        doomed = _node_group_get(context, db, node_group_id)
        if doomed:
            db.delete(doomed)
        else:
            raise ex.NotFoundException(node_group_id,
                                       "Node Group id '%s' not found!")
## Instance ops
def _instance_get(context, session, instance_id):
    """Fetch the first Instance with the given id using ``session``."""
    matching = model_query(m.Instance, context, session).filter_by(
        id=instance_id)
    return matching.first()
def instance_add(context, node_group_id, values):
    """Add an Instance to a node group and increment the group's count.

    :param values: dict of Instance fields applied after ``node_group_id``
    :returns: id of the newly added instance
    :raises ex.NotFoundException: if the node group is not found
    """
    db = get_session()
    with db.begin():
        if not _node_group_get(context, db, node_group_id):
            raise ex.NotFoundException(node_group_id,
                                       "Node Group id '%s' not found!")

        new_instance = m.Instance()
        new_instance.update({"node_group_id": node_group_id})
        new_instance.update(values)
        db.add(new_instance)

        # The group is fetched again after the instance has been added,
        # and only then is its instance counter bumped.
        owner = _node_group_get(context, db, node_group_id)
        owner.count += 1

    return new_instance.id
def instance_update(context, instance_id, values):
    """Apply ``values`` to an existing Instance inside one transaction.

    :raises ex.NotFoundException: if the instance is not found
    """
    db = get_session()
    with db.begin():
        target = _instance_get(context, db, instance_id)
        if target:
            target.update(values)
        else:
            raise ex.NotFoundException(instance_id,
                                       "Instance id '%s' not found!")
def instance_remove(context, instance_id):
    """Delete an Instance and decrement its node group's count.

    :raises ex.NotFoundException: if the instance is not found
    """
    db = get_session()
    with db.begin():
        doomed = _instance_get(context, db, instance_id)
        if not doomed:
            raise ex.NotFoundException(instance_id,
                                       "Instance id '%s' not found!")

        db.delete(doomed)

        # The owning group is looked up after the delete has been issued,
        # using the id still held by the instance object.
        owner = _node_group_get(context, db, doomed.node_group_id)
        owner.count -= 1
## Volumes ops
def append_volume(context, instance_id, volume_id):
    """Append ``volume_id`` to the ``volumes`` list of an Instance.

    :raises ex.NotFoundException: if the instance is not found
    """
    db = get_session()
    with db.begin():
        owner = _instance_get(context, db, instance_id)
        if owner:
            owner.volumes.append(volume_id)
        else:
            raise ex.NotFoundException(instance_id,
                                       "Instance id '%s' not found!")
def remove_volume(context, instance_id, volume_id):
    """Remove ``volume_id`` from the ``volumes`` list of an Instance.

    :raises ex.NotFoundException: if the instance is not found
    """
    db = get_session()
    with db.begin():
        owner = _instance_get(context, db, instance_id)
        if owner:
            owner.volumes.remove(volume_id)
        else:
            raise ex.NotFoundException(instance_id,
                                       "Instance id '%s' not found!")
## Cluster Template ops
def _cluster_template_get(context, session, cluster_template_id):
    """Fetch the first ClusterTemplate with the given id using ``session``."""
    matching = model_query(m.ClusterTemplate, context, session).filter_by(
        id=cluster_template_id)
    return matching.first()
def cluster_template_get(context, cluster_template_id):
    """Look up a single ClusterTemplate by id on a new session."""
    session = get_session()
    return _cluster_template_get(context, session, cluster_template_id)
def cluster_template_get_all(context):
    """Return every ClusterTemplate visible under ``context``."""
    return model_query(m.ClusterTemplate, context).all()
def cluster_template_create(context, values):
    """Create a ClusterTemplate and its node group relations.

    :param context: request context, used to re-read the created template
    :param values: dict of ClusterTemplate fields; an optional
                   "node_groups" key holds a list of dicts, one per
                   TemplatesRelation to create. The dict is copied, so the
                   caller's copy is untouched.
    :returns: the template as re-read through ``cluster_template_get``
    :raises ex.DBDuplicateEntry: if saving the template or one of its
                                 relations hits a duplicate entry
    """
    values = values.copy()
    cluster_template = m.ClusterTemplate()
    node_groups = values.pop("node_groups", [])
    cluster_template.update(values)

    session = get_session()
    with session.begin():
        try:
            cluster_template.save(session=session)
        except db_exc.DBDuplicateEntry as e:
            raise ex.DBDuplicateEntry(
                "Duplicate entry for ClusterTemplate: %s" % e.columns)

        try:
            for ng in node_groups:
                node_group = m.TemplatesRelation()
                node_group.update(
                    {"cluster_template_id": cluster_template.id})
                node_group.update(ng)
                node_group.save(session=session)
        except db_exc.DBDuplicateEntry as e:
            # Space after the colon added so the message is formatted like
            # the other duplicate-entry messages in this module.
            raise ex.DBDuplicateEntry(
                "Duplicate entry for TemplatesRelation: %s" % e.columns)

    return cluster_template_get(context, cluster_template.id)
def cluster_template_destroy(context, cluster_template_id):
    """Delete the ClusterTemplate with the given id in one transaction.

    :raises ex.NotFoundException: if the template is not found
    """
    db = get_session()
    with db.begin():
        doomed = _cluster_template_get(context, db, cluster_template_id)
        if doomed:
            db.delete(doomed)
        else:
            raise ex.NotFoundException(
                cluster_template_id,
                "Cluster Template id '%s' not found!")
## Node Group Template ops
def _node_group_template_get(context, session, node_group_template_id):
    """Fetch the first NodeGroupTemplate with the given id."""
    matching = model_query(m.NodeGroupTemplate, context, session).filter_by(
        id=node_group_template_id)
    return matching.first()
def node_group_template_get(context, node_group_template_id):
    """Look up a single NodeGroupTemplate by id on a new session."""
    session = get_session()
    return _node_group_template_get(context, session,
                                    node_group_template_id)
def node_group_template_get_all(context):
    """Return every NodeGroupTemplate visible under ``context``."""
    return model_query(m.NodeGroupTemplate, context).all()
def node_group_template_create(context, values):
    """Create and save a NodeGroupTemplate from ``values``.

    :returns: the saved template object
    :raises ex.DBDuplicateEntry: if the save hits a duplicate entry
    """
    template = m.NodeGroupTemplate()
    template.update(values)

    try:
        template.save()
    except db_exc.DBDuplicateEntry as dup:
        raise ex.DBDuplicateEntry(
            "Duplicate entry for NodeGroupTemplate: %s" % dup.columns)
    else:
        return template
def node_group_template_destroy(context, node_group_template_id):
    """Delete the NodeGroupTemplate with the given id in one transaction.

    :raises ex.NotFoundException: if the template is not found
    """
    db = get_session()
    with db.begin():
        doomed = _node_group_template_get(context, db,
                                          node_group_template_id)
        if doomed:
            db.delete(doomed)
        else:
            raise ex.NotFoundException(
                node_group_template_id,
                "Node Group Template id '%s' not found!")
## Data Source ops
def _data_source_get(context, session, data_source_id):
    """Fetch the first DataSource with the given id using ``session``."""
    matching = model_query(m.DataSource, context, session).filter_by(
        id=data_source_id)
    return matching.first()
def data_source_get(context, data_source_id):
    """Look up a single DataSource by id on a new session."""
    session = get_session()
    return _data_source_get(context, session, data_source_id)
def data_source_get_all(context):
    """Return every DataSource visible under ``context``."""
    return model_query(m.DataSource, context).all()
def data_source_create(context, values):
    """Create and save a DataSource from ``values``.

    :returns: the saved data source object
    :raises ex.DBDuplicateEntry: if the save hits a duplicate entry
    """
    source = m.DataSource()
    source.update(values)

    try:
        source.save()
    except db_exc.DBDuplicateEntry as dup:
        raise ex.DBDuplicateEntry("Duplicate entry for DataSource: %s"
                                  % dup.columns)
    else:
        return source
def data_source_destroy(context, data_source_id):
    """Delete the DataSource with the given id inside one transaction.

    :raises ex.NotFoundException: if the data source is not found
    """
    db = get_session()
    with db.begin():
        doomed = _data_source_get(context, db, data_source_id)
        if doomed:
            db.delete(doomed)
        else:
            raise ex.NotFoundException(data_source_id,
                                       "Data Source id '%s' not found!")
## JobExecution ops
def _job_execution_get(context, session, job_execution_id):
    """Fetch the first JobExecution with the given id using ``session``."""
    matching = model_query(m.JobExecution, context, session).filter_by(
        id=job_execution_id)
    return matching.first()
def job_execution_get(context, job_execution_id):
    """Look up a single JobExecution by id on a new session."""
    session = get_session()
    return _job_execution_get(context, session, job_execution_id)
def job_execution_get_all(context, **kwargs):
    """Return all JobExecutions whose columns match the keyword filters."""
    filtered = model_query(m.JobExecution, context).filter_by(**kwargs)
    return filtered.all()
def job_execution_count(context, **kwargs):
    """Return the number of JobExecutions matching the keyword filters.

    The count query is built with ``count_query`` defaults, i.e. without
    passing ``project_only``.
    """
    counting = count_query(m.JobExecution, context).filter_by(**kwargs)
    # The query yields a single row whose first column is the count.
    row = counting.first()
    return row[0]
def job_execution_create(context, values):
    """Create and save a JobExecution from ``values`` in one transaction.

    :param context: request context (not used by the body)
    :param values: dict of JobExecution fields
    :returns: the saved job execution object
    :raises ex.DBDuplicateEntry: if the save hits a duplicate entry
    """
    session = get_session()

    with session.begin():
        job_ex = m.JobExecution()
        job_ex.update(values)
        try:
            # Save through the session whose transaction was just opened,
            # as cluster_create and job_create do; calling save() without
            # it would leave this transaction unused.
            job_ex.save(session=session)
        except db_exc.DBDuplicateEntry as e:
            raise ex.DBDuplicateEntry("Duplicate entry for JobExecution: %s"
                                      % e.columns)

    return job_ex
def job_execution_update(context, job_execution_id, values):
    """Apply ``values`` to an existing JobExecution in one transaction.

    :returns: the updated job execution object
    :raises ex.NotFoundException: if the job execution is not found
    """
    db = get_session()
    with db.begin():
        target = _job_execution_get(context, db, job_execution_id)
        if target:
            target.update(values)
        else:
            raise ex.NotFoundException(job_execution_id,
                                       "JobExecution id '%s' not found!")

    return target
def job_execution_destroy(context, job_execution_id):
    """Delete the JobExecution with the given id in one transaction.

    :raises ex.NotFoundException: if the job execution is not found
    """
    db = get_session()
    with db.begin():
        doomed = _job_execution_get(context, db, job_execution_id)
        if doomed:
            db.delete(doomed)
        else:
            raise ex.NotFoundException(job_execution_id,
                                       "JobExecution id '%s' not found!")
## Job ops
def _job_get(context, session, job_id):
    """Fetch the first Job with the given id using ``session``."""
    matching = model_query(m.Job, context, session).filter_by(id=job_id)
    return matching.first()
def job_get(context, job_id):
    """Look up a single Job by id on a new session."""
    session = get_session()
    return _job_get(context, session, job_id)
def job_get_all(context):
    """Return every Job visible under ``context``."""
    return model_query(m.Job, context).all()
def _append_job_binaries(context, session, from_list, to_list):
    """Resolve JobBinary ids and append the found objects to ``to_list``.

    :param from_list: iterable of JobBinary ids to look up
    :param to_list: list that receives each JobBinary that was found;
                    ids with no matching row are skipped
    """
    for binary_id in from_list:
        lookup = model_query(m.JobBinary, context, session)
        found = lookup.filter_by(id=binary_id).first()
        if found is None:
            continue
        to_list.append(found)
def job_create(context, values):
    """Create a Job and attach its main and library binaries.

    :param context: request context used to scope the JobBinary lookups
    :param values: dict of Job fields; optional "mains" and "libs" keys
                   hold lists of JobBinary ids. The dict is copied, so the
                   caller's copy is untouched.
    :returns: the saved job object
    :raises ex.DBDuplicateEntry: if the save hits a duplicate entry
    """
    # Work on a copy so that popping "mains"/"libs" does not strip those
    # keys from the caller's dict (cluster_create copies for the same
    # reason).
    values = values.copy()
    mains = values.pop("mains", [])
    libs = values.pop("libs", [])

    session = get_session()
    with session.begin():
        job = m.Job()
        job.update(values)
        # libs and mains are 'lazy' objects. The initialization below
        # is needed here because it provides libs and mains to be
        # initialized within a session even if the lists are empty
        job.mains = []
        job.libs = []
        try:
            _append_job_binaries(context, session, mains, job.mains)
            _append_job_binaries(context, session, libs, job.libs)

            job.save(session=session)
        except db_exc.DBDuplicateEntry as e:
            raise ex.DBDuplicateEntry("Duplicate entry for Job: %s"
                                      % e.columns)

    return job
def job_update(context, job_id, values):
    """Apply ``values`` to an existing Job inside one transaction.

    :returns: the updated job object
    :raises ex.NotFoundException: if the job is not found
    """
    db = get_session()
    with db.begin():
        target = _job_get(context, db, job_id)
        if target:
            target.update(values)
        else:
            raise ex.NotFoundException(job_id,
                                       "Job id '%s' not found!")

    return target
def job_destroy(context, job_id):
    """Delete the Job with the given id inside one transaction.

    :raises ex.NotFoundException: if the job is not found
    """
    db = get_session()
    with db.begin():
        doomed = _job_get(context, db, job_id)
        if doomed:
            db.delete(doomed)
        else:
            raise ex.NotFoundException(job_id,
                                       "Job id '%s' not found!")
## JobBinary ops
def _job_binary_get(context, session, job_binary_id):
    """Fetch the first JobBinary with the given id using ``session``."""
    matching = model_query(m.JobBinary, context, session).filter_by(
        id=job_binary_id)
    return matching.first()
def job_binary_get_all(context):
    """Return all JobBinary objects, without their data field.

    The data column uses deferred loading.
    """
    return model_query(m.JobBinary, context).all()
def job_binary_get(context, job_binary_id):
    """Return one JobBinary by id, without its data field.

    The data column uses deferred loading.
    """
    session = get_session()
    return _job_binary_get(context, session, job_binary_id)
def job_binary_create(context, values):
    """Create and save a JobBinary; the result has no data field loaded.

    The data column uses deferred loading.

    :returns: the saved job binary object
    :raises ex.DBDuplicateEntry: if the save hits a duplicate entry
    """
    binary = m.JobBinary()
    binary.update(values)

    try:
        binary.save()
    except db_exc.DBDuplicateEntry as dup:
        raise ex.DBDuplicateEntry("Duplicate entry for JobBinary: %s"
                                  % dup.columns)
    else:
        return binary
def _check_job_binary_referenced(ctx, session, job_binary_id):
    """Tell whether any mains/libs association row uses the JobBinary.

    Both association tables are queried without tenant restriction
    (``project_only=False``); the libs table is only consulted when the
    mains table has no matching row.

    :returns: True if a referencing row exists, False otherwise
    """
    criteria = {"JobBinary_id": job_binary_id}

    for association in (m.mains_association, m.libs_association):
        references = model_query(association, ctx, session,
                                 project_only=False).filter_by(**criteria)
        if references.first() is not None:
            return True
    return False
def job_binary_destroy(context, job_binary_id):
    """Delete a JobBinary unless it is still referenced.

    :param context: request context used to scope the lookup
    :param job_binary_id: id of the JobBinary to delete
    :raises ex.NotFoundException: if the job binary is not found
    :raises ex.DeletionFailed: if a mains/libs association still
                               references the job binary
    """
    session = get_session()
    with session.begin():
        job_binary = _job_binary_get(context, session, job_binary_id)
        if not job_binary:
            raise ex.NotFoundException(job_binary_id,
                                       "JobBinary id '%s' not found!")

        if _check_job_binary_referenced(context, session, job_binary_id):
            # The two literals are concatenated, so the separating space
            # has to be part of one of them.
            raise ex.DeletionFailed("JobBinary is referenced "
                                    "and cannot be deleted")

        session.delete(job_binary)
## JobBinaryInternal ops
def _job_binary_internal_get(context, session, job_binary_internal_id):
    """Fetch the first JobBinaryInternal with the given id."""
    matching = model_query(m.JobBinaryInternal, context, session).filter_by(
        id=job_binary_internal_id)
    return matching.first()
def job_binary_internal_get_all(context):
    """Return all JobBinaryInternal objects, without their data field.

    The data column uses deferred loading.
    """
    return model_query(m.JobBinaryInternal, context).all()
def job_binary_internal_get(context, job_binary_internal_id):
    """Return one JobBinaryInternal by id, without its data field.

    The data column uses deferred loading.
    """
    session = get_session()
    return _job_binary_internal_get(context, session,
                                    job_binary_internal_id)
def job_binary_internal_get_raw_data(context, job_binary_internal_id):
    """Return only the data field of the specified JobBinaryInternal.

    :returns: the ``data`` attribute of the matching row, or None when no
              row matches
    """
    # Explicitly undefer "data" so it is loaded with the row.
    with_data = model_query(m.JobBinaryInternal, context).options(
        sa.orm.undefer("data"))
    record = with_data.filter_by(id=job_binary_internal_id).first()

    return None if record is None else record.data
def job_binary_internal_create(context, values):
    """Create a JobBinaryInternal; the result has no data field loaded.

    The data column uses deferred loading. After saving, the object is
    re-read through ``job_binary_internal_get`` and that result returned.

    :raises ex.DBDuplicateEntry: if the save hits a duplicate entry
    """
    internal = m.JobBinaryInternal()
    internal.update(values)

    try:
        internal.save()
    except db_exc.DBDuplicateEntry as dup:
        raise ex.DBDuplicateEntry(
            "Duplicate entry for JobBinaryInternal: %s" % dup.columns)

    return job_binary_internal_get(context, internal.id)
def job_binary_internal_destroy(context, job_binary_internal_id):
    """Delete the JobBinaryInternal with the given id in one transaction.

    :raises ex.NotFoundException: if the object is not found
    """
    db = get_session()
    with db.begin():
        doomed = _job_binary_internal_get(context, db,
                                          job_binary_internal_id)
        if doomed:
            db.delete(doomed)
        else:
            raise ex.NotFoundException(
                job_binary_internal_id,
                "JobBinaryInternal id '%s' not found!")