diff --git a/craton/db/sqlalchemy/api.py b/craton/db/sqlalchemy/api.py index 723aa4b..533ffe3 100644 --- a/craton/db/sqlalchemy/api.py +++ b/craton/db/sqlalchemy/api.py @@ -2,6 +2,7 @@ import enum import functools +from operator import attrgetter import sys import uuid @@ -12,6 +13,7 @@ from oslo_db.sqlalchemy import session from oslo_db.sqlalchemy import utils as db_utils from oslo_log import log +from sqlalchemy import or_, sql import sqlalchemy.orm.exc as sa_exc from sqlalchemy.orm import with_polymorphic @@ -154,17 +156,119 @@ def model_query(context, model, *args, **kwargs): model=model, session=session, args=args, **kwargs) -def add_var_filters_to_query(query, filters): - # vars filters are of form ?vars=a:b - query = query.join(models.VariableAssociation) - query = query.join(models.Variable) - var_filters = filters['vars'].split(',') - for filters in var_filters: - k, v = filters.split(':', 1) - query = query.filter_by(key=k) - query = query.filter_by(value=v) +def _generate_or_clauses(kv_pairs): + or_clauses = [] + for k, v in kv_pairs: + or_clauses.append( + ((models.Variable.key == k) & (models.Variable.value == v))) + return or_clauses - return query + +def _matching_resources(query, resource_cls, get_descendants, kv): + # NOTE(jimbaker) The below algorithm works as follows: + # + # 1. Computes the generalized descendants for each k:v var in the query; + # + # 2. Computes their intersection, returning a set of matching + # resources (empty set if conjunction of kv matches does not + # match) + + kv_pairs = list(kv.items()) + matches = dict((kv_pair, set()) for kv_pair in kv_pairs) + + # NOTE(jimbaker) this query can be readily generalized. Some + # options could include: + # + # * Key existence (good for treating vars as if they are labels) + # * JSON path matches on the values + # * Nested queries that use JSON paths for the underlying implementation + # + # But for now, simply find all variables that explicitly match one + # or more key value pairs. + # + # Regardless of any generalization, this means at this point we + # need to construct the disjunction ("or") of all the supplied kv + # pairs. (The next step will then compute the conjunction, but + # with respect to resolution.) + q = query.session.query(models.Variable) + q = q.filter(or_(*_generate_or_clauses(kv_pairs))) + variables = set(q) + for variable in variables: + match = matches[(variable.key, variable.value)] + if isinstance(variable.parent, resource_cls): + match.add(variable.parent) + for descendant in get_descendants(variable.parent): + for level in descendant.resolution_order: + desc_variable = level._variables.get(variable.key) + if desc_variable is not None: + if desc_variable in variables: + match.add(descendant) + break + + # NOTE(jimbaker) For now, we simply match for the conjunction + # ("and") of all the supplied kv pairs we are matching + # against. Generalize as desired with other boolean logic. + _, first_match = matches.popitem() + if matches: + resources = first_match.intersection(*matches.values()) + else: + resources = first_match + return resources + + +def _get_devices(parent): + if isinstance(parent, models.Device): + return parent.descendants + else: + return parent.devices + + +_resource_mapping = { + models.Project: ([], None), + models.Cloud: ([models.Project], attrgetter('clouds')), + models.Region: ([models.Project, models.Cloud], attrgetter('regions')), + models.Cell: ( + [models.Project, models.Cloud, models.Region], + attrgetter('cells')), + models.Device: ( + [models.Project, models.Cloud, models.Region, models.Cell, + models.Device], + _get_devices), +} + + +def matching_resources(query, resource_cls, kv, resolved): + def get_desc(parent): + parent_classes, getter = _resource_mapping[resource_cls] + # NOTE(thomasem): If we're not resolving, there are no descendants + # to process, so return an empty list. + if resolved and any(isinstance(parent, cls) for cls in parent_classes): + return getter(parent) + else: + return [] + return _matching_resources(query, resource_cls, get_desc, kv) + + +def _add_var_filters_to_query(query, model, var_filters, resolved=True): + # vars filters are of form ?vars=a:b[,c:d,...] - the filters in + # this case are intersecting ("and" queries) + kv = dict(pairing.split(':', 1) for pairing in var_filters) + resource_ids = set( + resource.id + for resource in matching_resources(query, model, kv, resolved) + ) + if not resource_ids: + # short circuit; this also avoids SQLAlchemy reporting that it is + # working with an empty in clause + return query.filter(sql.false()) + return query.filter(model.id.in_(resource_ids)) + + +def add_var_filters_to_query(query, model, filters): + var_filters = filters['vars'].split(',') + resolved = bool(filters.get('resolved-values')) + return _add_var_filters_to_query(query, model, var_filters, + resolved=resolved) def get_user_info(context, username): @@ -333,7 +437,7 @@ def cells_get_all(context, filters, pagination_params): if "name" in filters: query = query.filter_by(name=filters["name"]) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Cell, filters) return _paginate(context, query, models.Cell, session, filters, pagination_params) @@ -392,7 +496,7 @@ def regions_get_all(context, filters, pagination_params): session=session) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Region, filters) if "cloud_id" in filters: query = query.filter_by(cloud_id=filters["cloud_id"]) @@ -464,7 +568,7 @@ def clouds_get_all(context, filters, pagination_params): session=session) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Cloud, filters) return _paginate(context, query, models.Cloud, session, filters, pagination_params) @@ -558,7 +662,7 @@ def hosts_get_all(context, filters, pagination_params): query = query.filter(models.Device.related_labels.any( models.Label.label == filters["label"])) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Device, filters) return _paginate(context, query, models.Host, session, filters, pagination_params) @@ -642,7 +746,7 @@ def projects_get_all(context, filters, pagination_params): session = get_session() query = model_query(context, models.Project, session=session) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Project, filters) return _paginate(context, query, models.Project, session, filters, pagination_params) @@ -653,7 +757,7 @@ def projects_get_by_name(context, project_name, filters, pagination_params): query = model_query(context, models.Project) query = query.filter(models.Project.name.like(project_name)) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Project, filters) try: return _paginate(context, query, models.Project, session, filters, pagination_params) @@ -786,7 +890,7 @@ def networks_get_all(context, filters, pagination_params): if "name" in filters: query = query.filter_by(name=filters["name"]) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Network, filters) return _paginate(context, query, models.Network, session, filters, pagination_params) @@ -863,7 +967,7 @@ def network_devices_get_all(context, filters, pagination_params): if "device_type" in filters: query = query.filter_by(device_type=filters["device_type"]) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Device, filters) return _paginate(context, query, models.Device, session, filters, pagination_params) diff --git a/craton/tests/functional/test_host_calls.py b/craton/tests/functional/test_host_calls.py index a380025..e35fcc2 100644 --- a/craton/tests/functional/test_host_calls.py +++ b/craton/tests/functional/test_host_calls.py @@ -320,6 +320,87 @@ class APIV1HostTest(DeviceTestBase, APIV1ResourceWithVariablesTestCase): self.assertEqual(sorted([host['id'] for host in test_hosts]), sorted([host['id'] for host in hosts])) + def test_host_get_all_vars_filter_resolved_region(self): + region_vars = {'foo': 'bar'} + region = self.create_region(name='region-2', variables=region_vars) + host_vars = {'baz': 'zoo'} + self.create_host('host1', 'server', '192.168.1.1', **host_vars) + host2 = self.create_host('host2', 'server', '192.168.1.2', + region=region, **host_vars) + url = self.url + '/v1/hosts' + + resp = self.get(url, vars="foo:bar,baz:zoo") + hosts = resp.json()['hosts'] + self.assertEqual(1, len(hosts)) + self.assertEqual(host2['id'], hosts[0]['id']) + + def test_host_get_all_vars_filter_resolved_region_and_host(self): + region_vars = {'foo': 'bar'} + region = self.create_region(name='region-2', variables=region_vars) + host_vars = {'baz': 'zoo'} + host1 = self.create_host('host1', 'server', '192.168.1.1', + **region_vars) + host2 = self.create_host('host2', 'server', '192.168.1.2', + region=region, **host_vars) + url = self.url + '/v1/hosts' + + resp = self.get(url, vars='foo:bar') + hosts = resp.json()['hosts'] + self.assertEqual(2, len(hosts)) + self.assertListEqual(sorted([host1['id'], host2['id']]), + sorted([host['id'] for host in hosts])) + + def test_host_get_all_vars_filter_resolved_region_child_override(self): + region_vars = {'foo': 'bar'} + region = self.create_region(name='region-2', variables=region_vars) + host1 = self.create_host('host1', 'server', '192.168.1.1', + region=region, foo='baz') + host2 = self.create_host('host2', 'server', '192.168.1.2', + region=region) + url = self.url + '/v1/hosts' + + resp = self.get(url, vars='foo:baz') + hosts = resp.json()['hosts'] + self.assertEqual(1, len(hosts)) + self.assertEqual(host1['id'], hosts[0]['id']) + + resp = self.get(url, vars='foo:bar') + hosts = resp.json()['hosts'] + self.assertEqual(1, len(hosts)) + self.assertEqual(host2['id'], hosts[0]['id']) + + def test_host_get_all_vars_filter_resolved_host_child_override(self): + host1 = self.create_host('host1', 'server', '192.168.1.1', + baz='zoo') + host2 = self.create_host('host2', 'server', '192.168.1.2', + parent_id=host1['id'], baz='boo') + url = self.url + '/v1/hosts' + + resp = self.get(url, vars='baz:zoo') + hosts = resp.json()['hosts'] + self.assertEqual(1, len(hosts)) + self.assertEqual(host1['id'], hosts[0]['id']) + + resp = self.get(url, vars='baz:boo') + hosts = resp.json()['hosts'] + self.assertEqual(1, len(hosts)) + self.assertEqual(host2['id'], hosts[0]['id']) + + def test_host_get_all_vars_filter_unresolved(self): + host1 = self.create_host('host1', 'server', '192.168.1.1', + foo='bar', baz='zoo') + self.create_host('host2', 'server', '192.168.1.2', foo='bar') + + # NOTE(thomasem): Unfortunately, we use resolved-values instead of + # resolved_values, so we can't pass this in as kwargs to self.get(...), + # see https://bugs.launchpad.net/craton/+bug/1672880. + url = self.url + '/v1/hosts?resolved-values=false&vars=foo:bar,baz:zoo' + + resp = self.get(url) + hosts = resp.json()['hosts'] + self.assertEqual(1, len(hosts)) + self.assertEqual(host1['id'], hosts[0]['id']) + def test_host_delete(self): host = self.create_host('host1', 'server', '192.168.1.1') url = self.url + '/v1/hosts/{}'.format(host['id']) diff --git a/craton/tests/unit/db/test_devices.py b/craton/tests/unit/db/test_devices.py index bcb8752..bcd9b0a 100644 --- a/craton/tests/unit/db/test_devices.py +++ b/craton/tests/unit/db/test_devices.py @@ -538,7 +538,7 @@ class HostsDBTestCase(BaseDevicesDBTestCase): self.context, "hosts", host_id, variables ) filters = { - "region_id": "region_1", + "region_id": 1, "vars": "key1:value5", } res, _ = dbapi.hosts_get_all(self.context, filters, @@ -672,3 +672,33 @@ class HostsDBTestCase(BaseDevicesDBTestCase): 'parent_id': grandchild.id, } ) + + def test_hosts_get_all_with_resolved_var_filters(self): + project_id = self.make_project('project_1', foo='P1', zoo='P2') + cloud_id = self.make_cloud(project_id, 'cloud_1') + region_id = self.make_region( + project_id, cloud_id, 'region_1', foo='R1') + switch_id = self.make_network_device( + project_id, cloud_id, region_id, + 'switch1.example.com', IPAddress('10.1.2.101'), 'switch', + zoo='S1', bar='S2') + self.make_host( + project_id, cloud_id, region_id, + 'www.example.xyz', IPAddress(u'10.1.2.101'), 'server', + parent_id=switch_id, + key1="value1", key2="value2") + self.make_host( + project_id, cloud_id, region_id, + 'www2.example.xyz', IPAddress(u'10.1.2.102'), 'server', + parent_id=switch_id, + key1="value-will-not-match", key2="value2") + + filters = { + "region_id": 1, + "vars": "key1:value1,zoo:S1,foo:R1", + "resolved-values": True, + } + res, _ = dbapi.hosts_get_all( + self.context, filters, default_pagination) + self.assertEqual(len(res), 1) + self.assertEqual(res[0].name, 'www.example.xyz')