From 1d12df921cf8747b43eeb72334594267243ef975 Mon Sep 17 00:00:00 2001 From: Sergey Lukjanov Date: Tue, 1 Oct 2013 16:59:42 +0400 Subject: [PATCH] Impl multitenancy support * tenant_id added to instances and nodegroup-like objects * tenant_id will be accessibly only inside conductor * enable tenant-specific model_query by default * add missed not None checks to sqla/api * ctx is now passed instead of module in savanna/service/instances.py Fixes: bug #1218452 Change-Id: Icb0c089fd47dd422fe0c93781d3a5747eb995615 --- savanna/conductor/manager.py | 3 + savanna/conductor/resource.py | 4 +- savanna/db/sqlalchemy/api.py | 120 ++++++++++-------- savanna/db/sqlalchemy/models.py | 3 + savanna/service/instances.py | 2 +- .../unit/conductor/manager/test_clusters.py | 4 +- .../unit/conductor/manager/test_templates.py | 1 + 7 files changed, 82 insertions(+), 55 deletions(-) diff --git a/savanna/conductor/manager.py b/savanna/conductor/manager.py index 4fa68e6929..75798302c2 100644 --- a/savanna/conductor/manager.py +++ b/savanna/conductor/manager.py @@ -75,6 +75,7 @@ class ConductorManager(db_base.Base): return for node_group in node_groups: + node_group["tenant_id"] = context.tenant_id self._populate_node_group(context, node_group) def _cleanup_node_group(self, node_group): @@ -161,6 +162,7 @@ class ConductorManager(db_base.Base): """Create a Node Group from the values dictionary.""" values = copy.deepcopy(values) self._populate_node_group(context, values) + values['tenant_id'] = context.tenant_id return self.db.node_group_add(context, cluster, values) def node_group_update(self, context, node_group, values): @@ -178,6 +180,7 @@ class ConductorManager(db_base.Base): """Create an Instance from the values dictionary.""" values = copy.deepcopy(values) values = _apply_defaults(values, INSTANCE_DEFAULTS) + values['tenant_id'] = context.tenant_id return self.db.instance_add(context, node_group, values) def instance_update(self, context, instance, values): diff --git a/savanna/conductor/resource.py b/savanna/conductor/resource.py index 37ec7f30cf..84b1f425f6 100644 --- a/savanna/conductor/resource.py +++ b/savanna/conductor/resource.py @@ -172,7 +172,7 @@ class NodeGroupTemplateResource(Resource, objects.NodeGroupTemplate): class InstanceResource(Resource, objects.Instance): - _filter_fields = ['node_group_id'] + _filter_fields = ['tenant_id', 'node_group_id'] class NodeGroupResource(Resource, objects.NodeGroup): @@ -181,7 +181,7 @@ class NodeGroupResource(Resource, objects.NodeGroup): 'node_group_template': (NodeGroupTemplateResource, None) } - _filter_fields = ['id', 'cluster_id', 'cluster_template_id'] + _filter_fields = ['id', 'tenant_id', 'cluster_id', 'cluster_template_id'] class ClusterTemplateResource(Resource, objects.ClusterTemplate): diff --git a/savanna/db/sqlalchemy/api.py b/savanna/db/sqlalchemy/api.py index 8f3732463d..2da0514cac 100644 --- a/savanna/db/sqlalchemy/api.py +++ b/savanna/db/sqlalchemy/api.py @@ -37,20 +37,20 @@ def get_backend(): return sys.modules[__name__] -def model_query(model, context, session=None, project_only=None): +def model_query(model, context, session=None, project_only=True): """Query helper. :param model: base model to query :param context: context to query under :param project_only: if present and context is user-type, then restrict - query to match the context's project_id. + query to match the context's tenant_id. """ session = session or get_session() query = session.query(model) - if project_only: - query = query.filter_by(tenant_id=context.project_id) + if project_only and not context.is_admin: + query = query.filter_by(tenant_id=context.tenant_id) return query @@ -188,7 +188,6 @@ def cluster_destroy(context, cluster_id): session = get_session() with session.begin(): cluster = _cluster_get(context, session, cluster_id) - if not cluster: raise ex.NotFoundException(cluster_id, "Cluster id '%s' not found!") @@ -207,6 +206,11 @@ def node_group_add(context, cluster_id, values): session = get_session() with session.begin(): + cluster = _cluster_get(context, session, cluster_id) + if not cluster: + raise ex.NotFoundException(cluster_id, + "Cluster id '%s' not found!") + node_group = m.NodeGroup() node_group.update({"cluster_id": cluster_id}) node_group.update(values) @@ -219,14 +223,18 @@ def node_group_update(context, node_group_id, values): session = get_session() with session.begin(): node_group = _node_group_get(context, session, node_group_id) + if not node_group: + raise ex.NotFoundException(node_group_id, + "Node Group id '%s' not found!") + node_group.update(values) def node_group_remove(context, node_group_id): session = get_session() + with session.begin(): node_group = _node_group_get(context, session, node_group_id) - if not node_group: raise ex.NotFoundException(node_group_id, "Node Group id '%s' not found!") @@ -245,6 +253,11 @@ def instance_add(context, node_group_id, values): session = get_session() with session.begin(): + node_group = _node_group_get(context, session, node_group_id) + if not node_group: + raise ex.NotFoundException(node_group_id, + "Node Group id '%s' not found!") + instance = m.Instance() instance.update({"node_group_id": node_group_id}) instance.update(values) @@ -260,6 +273,10 @@ def instance_update(context, instance_id, values): session = get_session() with session.begin(): instance = _instance_get(context, session, instance_id) + if not instance: + raise ex.NotFoundException(instance_id, + "Instance id '%s' not found!") + instance.update(values) @@ -267,7 +284,6 @@ def instance_remove(context, instance_id): session = get_session() with session.begin(): instance = _instance_get(context, session, instance_id) - if not instance: raise ex.NotFoundException(instance_id, "Instance id '%s' not found!") @@ -285,6 +301,10 @@ def append_volume(context, instance_id, volume_id): session = get_session() with session.begin(): instance = _instance_get(context, session, instance_id) + if not instance: + raise ex.NotFoundException(instance_id, + "Instance id '%s' not found!") + instance.volumes.append(volume_id) @@ -292,6 +312,10 @@ def remove_volume(context, instance_id, volume_id): session = get_session() with session.begin(): instance = _instance_get(context, session, instance_id) + if not instance: + raise ex.NotFoundException(instance_id, + "Instance id '%s' not found!") + instance.volumes.remove(volume_id) @@ -344,7 +368,6 @@ def cluster_template_destroy(context, cluster_template_id): with session.begin(): cluster_template = _cluster_template_get(context, session, cluster_template_id) - if not cluster_template: raise ex.NotFoundException(cluster_template_id, "Cluster Template id '%s' not found!") @@ -387,7 +410,6 @@ def node_group_template_destroy(context, node_group_template_id): with session.begin(): node_group_template = _node_group_template_get(context, session, node_group_template_id) - if not node_group_template: raise ex.NotFoundException( node_group_template_id, @@ -429,7 +451,6 @@ def data_source_destroy(context, data_source_id): session = get_session() with session.begin(): data_source = _data_source_get(context, session, data_source_id) - if not data_source: raise ex.NotFoundException(data_source_id, "Data Source id '%s' not found!") @@ -490,7 +511,6 @@ def job_execution_destroy(context, job_execution_id): session = get_session() with session.begin(): job_ex = _job_execution_get(context, session, job_execution_id) - if not job_ex: raise ex.NotFoundException(job_execution_id, "JobExecution id '%s' not found!") @@ -514,6 +534,14 @@ def job_get_all(context): return query.all() +def _append_job_binaries(context, session, from_list, to_list): + for job_binary_id in from_list: + job_binary = model_query( + m.JobBinary, context, session).filter_by(id=job_binary_id).first() + if job_binary is not None: + to_list.append(job_binary) + + def job_create(context, values): mains = values.pop("mains", []) libs = values.pop("libs", []) @@ -528,19 +556,8 @@ def job_create(context, values): job.mains = [] job.libs = [] try: - for main in mains: - query = model_query(m.JobBinary, - context, session).filter_by(id=main) - job_binary = query.first() - if job_binary is not None: - job.mains.append(job_binary) - - for lib in libs: - query = model_query(m.JobBinary, - context, session).filter_by(id=lib) - job_binary = query.first() - if job_binary is not None: - job.libs.append(job_binary) + _append_job_binaries(context, session, mains, job.mains) + _append_job_binaries(context, session, libs, job.libs) job.save(session=session) except db_exc.DBDuplicateEntry as e: @@ -567,15 +584,19 @@ def job_destroy(context, job_id): session = get_session() with session.begin(): job = _job_get(context, session, job_id) - if not job: raise ex.NotFoundException(job_id, "Job id '%s' not found!") session.delete(job) + ## JobBinary ops +def _job_binary_get(context, session, job_binary_id): + query = model_query(m.JobBinary, context, session) + return query.filter_by(id=job_binary_id).first() + def job_binary_get_all(context): """Returns JobBinary objects that do not contain a data field @@ -591,8 +612,7 @@ def job_binary_get(context, job_binary_id): The data column uses deferred loadling. """ - query = model_query(m.JobBinary, context).filter_by(id=job_binary_id) - return query.first() + return _job_binary_get(context, get_session(), job_binary_id) def job_binary_create(context, values): @@ -612,35 +632,37 @@ def job_binary_create(context, values): return job_binary -def _check_job_binary_referenced(session, id): +def _check_job_binary_referenced(ctx, session, job_binary_id): + args = {"JobBinary_id": job_binary_id} + mains = model_query(m.mains_association, ctx, session, + project_only=False).filter_by(**args) + libs = model_query(m.libs_association, ctx, session, + project_only=False).filter_by(**args) - args = {"JobBinary_id": id} - return model_query(m.mains_association, - None, session).filter_by(**args).first() is not None or\ - model_query(m.libs_association, - None, session).filter_by(**args).first() is not None + return mains.first() is not None or libs.first() is not None def job_binary_destroy(context, job_binary_id): session = get_session() with session.begin(): - - job_binary = model_query(m.JobBinary, - context, - session).filter_by(id=job_binary_id).first() - + job_binary = _job_binary_get(context, session, job_binary_id) if not job_binary: raise ex.NotFoundException(job_binary_id, "JobBinary id '%s' not found!") - if _check_job_binary_referenced(session, job_binary.id): + if _check_job_binary_referenced(context, session, job_binary_id): raise ex.DeletionFailed("JobBinary is referenced" "and cannot be deleted") session.delete(job_binary) + ## JobBinaryInternal ops +def _job_binary_internal_get(context, session, job_binary_internal_id): + query = model_query(m.JobBinaryInternal, context, session) + return query.filter_by(id=job_binary_internal_id).first() + def job_binary_internal_get_all(context): """Returns JobBinaryInternal objects that do not contain a data field @@ -656,15 +678,14 @@ def job_binary_internal_get(context, job_binary_internal_id): The data column uses deferred loadling. """ - query = model_query(m.JobBinaryInternal, context).filter_by( - id=job_binary_internal_id) - return query.first() + return _job_binary_internal_get(context, get_session(), + job_binary_internal_id) def job_binary_internal_get_raw_data(context, job_binary_internal_id): """Returns only the data field for the specified JobBinaryInternal.""" - query = model_query(m.JobBinaryInternal, context).options( - sa.orm.undefer("data")) + query = model_query(m.JobBinaryInternal, context) + query = query.options(sa.orm.undefer("data")) res = query.filter_by(id=job_binary_internal_id).first() if res is not None: res = res.data @@ -691,13 +712,10 @@ def job_binary_internal_create(context, values): def job_binary_internal_destroy(context, job_binary_internal_id): session = get_session() with session.begin(): - - b_intrnl = model_query(m.JobBinaryInternal, - context - ).filter_by(id=job_binary_internal_id).first() - - if not b_intrnl: + job_binary_internal = _job_binary_internal_get(context, session, + job_binary_internal_id) + if not job_binary_internal: raise ex.NotFoundException(job_binary_internal_id, "JobBinaryInternal id '%s' not found!") - session.delete(b_intrnl) + session.delete(job_binary_internal) diff --git a/savanna/db/sqlalchemy/models.py b/savanna/db/sqlalchemy/models.py index 0c22a655ac..80384fcb0d 100644 --- a/savanna/db/sqlalchemy/models.py +++ b/savanna/db/sqlalchemy/models.py @@ -85,6 +85,7 @@ class NodeGroup(mb.SavannaBase): id = _id_column() name = sa.Column(sa.String(80), nullable=False) + tenant_id = sa.Column(sa.String(36)) flavor_id = sa.Column(sa.String(36), nullable=False) image_id = sa.Column(sa.String(36)) node_processes = sa.Column(st.JsonListType()) @@ -121,6 +122,7 @@ class Instance(mb.SavannaBase): ) id = _id_column() + tenant_id = sa.Column(sa.String(36)) node_group_id = sa.Column(sa.String(36), sa.ForeignKey('node_groups.id')) instance_id = sa.Column(sa.String(36)) instance_name = sa.Column(sa.String(80), nullable=False) @@ -194,6 +196,7 @@ class TemplatesRelation(mb.SavannaBase): __tablename__ = 'templates_relations' id = _id_column() + tenant_id = sa.Column(sa.String(36)) name = sa.Column(sa.String(80), nullable=False) flavor_id = sa.Column(sa.String(36), nullable=False) image_id = sa.Column(sa.String(36)) diff --git a/savanna/service/instances.py b/savanna/service/instances.py index 9c75ee2dbc..69be532695 100644 --- a/savanna/service/instances.py +++ b/savanna/service/instances.py @@ -106,7 +106,7 @@ def scale_cluster(cluster, node_group_id_map, plugin): _await_networks(instances) - cluster = conductor.cluster_get(context, cluster) + cluster = conductor.cluster_get(ctx, cluster) volumes.attach_to_instances(get_instances(cluster, instance_ids)) diff --git a/savanna/tests/unit/conductor/manager/test_clusters.py b/savanna/tests/unit/conductor/manager/test_clusters.py index ed1702a91d..f589e341ef 100644 --- a/savanna/tests/unit/conductor/manager/test_clusters.py +++ b/savanna/tests/unit/conductor/manager/test_clusters.py @@ -109,6 +109,7 @@ class ClusterTest(test_base.ConductorManagerTestCase): ng.pop("volumes_size") ng.pop("volumes_per_node") ng.pop("floating_ip_pool") + ng.pop("tenant_id") self.assertListEqual(SAMPLE_CLUSTER["node_groups"], cl_db_obj["node_groups"]) @@ -212,6 +213,7 @@ class ClusterTest(test_base.ConductorManagerTestCase): if ng["id"] != ng_id: continue + ng.pop('tenant_id') self.assertEqual(count + 1, ng["count"]) self.assertEqual("additional_vm", ng["instances"][0]["instance_name"]) @@ -225,7 +227,7 @@ class ClusterTest(test_base.ConductorManagerTestCase): instance_id = self._add_instance(ctx, ng_id) - self.api.instance_update(context, instance_id, + self.api.instance_update(ctx, instance_id, {"management_ip": "1.1.1.1"}) cluster_db_obj = self.api.cluster_get(ctx, _id) diff --git a/savanna/tests/unit/conductor/manager/test_templates.py b/savanna/tests/unit/conductor/manager/test_templates.py index 73926446c4..287f4da1a0 100644 --- a/savanna/tests/unit/conductor/manager/test_templates.py +++ b/savanna/tests/unit/conductor/manager/test_templates.py @@ -174,6 +174,7 @@ class ClusterTemplates(test_base.ConductorManagerTestCase): ng.pop("created_at") ng.pop("updated_at") ng.pop("id") + ng.pop("tenant_id") self.assertEqual(ng.pop("cluster_template_id"), clt_db_obj_id) ng.pop("image_id") ng.pop("node_configs")