From 3d0506b69c255571ee6b580f85a0bf6646d7ed21 Mon Sep 17 00:00:00 2001 From: Jim Baker Date: Thu, 2 Mar 2017 21:56:28 -0700 Subject: [PATCH] Variable search for resources now uses resolved variables. Implements a filter addition for vars such that each key-value pair in the vars is searched for a possible match, regardless of associated resource; then chases the associated resources for the specific resource being searched for, thereby implementing the resolution algorithm (but in reverse). Finally all such resources are checked to see if they intersect with respect to their rooting key-value pairs in the original search. Examples: $ GET v1/hosts?vars=openstack_release:juno $ GET v1/hosts?vars=foo:abc,bar:xyz To disable this feature in your request, you can supply 'resolved-values=false', for example: $ GET /v1/hosts?vars=foo:bar,baz:zoo&resolved-values=false Filter values are not currently encoded for JSON; but this will be fixed in a subsequent patch that is addressing generalized JSON matching via JSON path. 
Change-Id: I1d40d734e60b5563dfb01da05ffb6494ed9a919c Closes-bug: 1661226 Closes-bug: 1669493 --- craton/db/sqlalchemy/api.py | 140 ++++++++++++++++++--- craton/tests/functional/test_host_calls.py | 81 ++++++++++++ craton/tests/unit/db/test_devices.py | 32 ++++- 3 files changed, 234 insertions(+), 19 deletions(-) diff --git a/craton/db/sqlalchemy/api.py b/craton/db/sqlalchemy/api.py index 723aa4b..533ffe3 100644 --- a/craton/db/sqlalchemy/api.py +++ b/craton/db/sqlalchemy/api.py @@ -2,6 +2,7 @@ import enum import functools +from operator import attrgetter import sys import uuid @@ -12,6 +13,7 @@ from oslo_db.sqlalchemy import session from oslo_db.sqlalchemy import utils as db_utils from oslo_log import log +from sqlalchemy import or_, sql import sqlalchemy.orm.exc as sa_exc from sqlalchemy.orm import with_polymorphic @@ -154,17 +156,119 @@ def model_query(context, model, *args, **kwargs): model=model, session=session, args=args, **kwargs) -def add_var_filters_to_query(query, filters): - # vars filters are of form ?vars=a:b - query = query.join(models.VariableAssociation) - query = query.join(models.Variable) - var_filters = filters['vars'].split(',') - for filters in var_filters: - k, v = filters.split(':', 1) - query = query.filter_by(key=k) - query = query.filter_by(value=v) +def _generate_or_clauses(kv_pairs): + or_clauses = [] + for k, v in kv_pairs: + or_clauses.append( + ((models.Variable.key == k) & (models.Variable.value == v))) + return or_clauses - return query + +def _matching_resources(query, resource_cls, get_descendants, kv): + # NOTE(jimbaker) The below algorithm works as follows: + # + # 1. Computes the generalized descendants for each k:v var in the query; + # + # 2. Computes their intersection, returning a set of matching + # resources (empty set if conjunction of kv matches does not + # match) + + kv_pairs = list(kv.items()) + matches = dict((kv_pair, set()) for kv_pair in kv_pairs) + + # NOTE(jimbaker) this query can be readily generalized. 
Some + # options could include: + # + # * Key existence (good for treating vars as if they are labels) + # * JSON path matches on the values + # * Nested queries that use JSON paths for the underlying implementation + # + # But for now, simply find all variables that explicitly match one + # or more key value pairs. + # + # Regardless of any generalization, this means at this point we + # need to construct the disjunction ("or") of all the supplied kv + # pairs. (The next step will then compute the conjunction, but + # with respect to resolution.) + q = query.session.query(models.Variable) + q = q.filter(or_(*_generate_or_clauses(kv_pairs))) + variables = set(q) + for variable in variables: + match = matches[(variable.key, variable.value)] + if isinstance(variable.parent, resource_cls): + match.add(variable.parent) + for descendant in get_descendants(variable.parent): + for level in descendant.resolution_order: + desc_variable = level._variables.get(variable.key) + if desc_variable is not None: + if desc_variable in variables: + match.add(descendant) + break + + # NOTE(jimbaker) For now, we simply match for the conjunction + # ("and") of all the supplied kv pairs we are matching + # against. Generalize as desired with other boolean logic. 
+ _, first_match = matches.popitem() + if matches: + resources = first_match.intersection(*matches.values()) + else: + resources = first_match + return resources + + +def _get_devices(parent): + if isinstance(parent, models.Device): + return parent.descendants + else: + return parent.devices + + +_resource_mapping = { + models.Project: ([], None), + models.Cloud: ([models.Project], attrgetter('clouds')), + models.Region: ([models.Project, models.Cloud], attrgetter('regions')), + models.Cell: ( + [models.Project, models.Cloud, models.Region], + attrgetter('cells')), + models.Device: ( + [models.Project, models.Cloud, models.Region, models.Cell, + models.Device], + _get_devices), +} + + +def matching_resources(query, resource_cls, kv, resolved): + def get_desc(parent): + parent_classes, getter = _resource_mapping[resource_cls] + # NOTE(thomasem): If we're not resolving, there are no descendants + # to process, so return an empty list. + if resolved and any(isinstance(parent, cls) for cls in parent_classes): + return getter(parent) + else: + return [] + return _matching_resources(query, resource_cls, get_desc, kv) + + +def _add_var_filters_to_query(query, model, var_filters, resolved=True): + # vars filters are of form ?vars=a:b[,c:d,...] 
- the filters in + # this case are intersecting ("and" queries) + kv = dict(pairing.split(':', 1) for pairing in var_filters) + resource_ids = set( + resource.id + for resource in matching_resources(query, model, kv, resolved) + ) + if not resource_ids: + # short circuit; this also avoids SQLAlchemy reporting that it is + # working with an empty in clause + return query.filter(sql.false()) + return query.filter(model.id.in_(resource_ids)) + + +def add_var_filters_to_query(query, model, filters): + var_filters = filters['vars'].split(',') + resolved = bool(filters.get('resolved-values')) + return _add_var_filters_to_query(query, model, var_filters, + resolved=resolved) def get_user_info(context, username): @@ -333,7 +437,7 @@ def cells_get_all(context, filters, pagination_params): if "name" in filters: query = query.filter_by(name=filters["name"]) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Cell, filters) return _paginate(context, query, models.Cell, session, filters, pagination_params) @@ -392,7 +496,7 @@ def regions_get_all(context, filters, pagination_params): session=session) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Region, filters) if "cloud_id" in filters: query = query.filter_by(cloud_id=filters["cloud_id"]) @@ -464,7 +568,7 @@ def clouds_get_all(context, filters, pagination_params): session=session) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Cloud, filters) return _paginate(context, query, models.Cloud, session, filters, pagination_params) @@ -558,7 +662,7 @@ def hosts_get_all(context, filters, pagination_params): query = query.filter(models.Device.related_labels.any( models.Label.label == filters["label"])) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, 
models.Device, filters) return _paginate(context, query, models.Host, session, filters, pagination_params) @@ -642,7 +746,7 @@ def projects_get_all(context, filters, pagination_params): session = get_session() query = model_query(context, models.Project, session=session) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Project, filters) return _paginate(context, query, models.Project, session, filters, pagination_params) @@ -653,7 +757,7 @@ def projects_get_by_name(context, project_name, filters, pagination_params): query = model_query(context, models.Project) query = query.filter(models.Project.name.like(project_name)) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Project, filters) try: return _paginate(context, query, models.Project, session, filters, pagination_params) @@ -786,7 +890,7 @@ def networks_get_all(context, filters, pagination_params): if "name" in filters: query = query.filter_by(name=filters["name"]) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Network, filters) return _paginate(context, query, models.Network, session, filters, pagination_params) @@ -863,7 +967,7 @@ def network_devices_get_all(context, filters, pagination_params): if "device_type" in filters: query = query.filter_by(device_type=filters["device_type"]) if "vars" in filters: - query = add_var_filters_to_query(query, filters) + query = add_var_filters_to_query(query, models.Device, filters) return _paginate(context, query, models.Device, session, filters, pagination_params) diff --git a/craton/tests/functional/test_host_calls.py b/craton/tests/functional/test_host_calls.py index a380025..e35fcc2 100644 --- a/craton/tests/functional/test_host_calls.py +++ b/craton/tests/functional/test_host_calls.py @@ -320,6 +320,87 @@ class APIV1HostTest(DeviceTestBase, 
APIV1ResourceWithVariablesTestCase): self.assertEqual(sorted([host['id'] for host in test_hosts]), sorted([host['id'] for host in hosts])) + def test_host_get_all_vars_filter_resolved_region(self): + region_vars = {'foo': 'bar'} + region = self.create_region(name='region-2', variables=region_vars) + host_vars = {'baz': 'zoo'} + self.create_host('host1', 'server', '192.168.1.1', **host_vars) + host2 = self.create_host('host2', 'server', '192.168.1.2', + region=region, **host_vars) + url = self.url + '/v1/hosts' + + resp = self.get(url, vars="foo:bar,baz:zoo") + hosts = resp.json()['hosts'] + self.assertEqual(1, len(hosts)) + self.assertEqual(host2['id'], hosts[0]['id']) + + def test_host_get_all_vars_filter_resolved_region_and_host(self): + region_vars = {'foo': 'bar'} + region = self.create_region(name='region-2', variables=region_vars) + host_vars = {'baz': 'zoo'} + host1 = self.create_host('host1', 'server', '192.168.1.1', + **region_vars) + host2 = self.create_host('host2', 'server', '192.168.1.2', + region=region, **host_vars) + url = self.url + '/v1/hosts' + + resp = self.get(url, vars='foo:bar') + hosts = resp.json()['hosts'] + self.assertEqual(2, len(hosts)) + self.assertListEqual(sorted([host1['id'], host2['id']]), + sorted([host['id'] for host in hosts])) + + def test_host_get_all_vars_filter_resolved_region_child_override(self): + region_vars = {'foo': 'bar'} + region = self.create_region(name='region-2', variables=region_vars) + host1 = self.create_host('host1', 'server', '192.168.1.1', + region=region, foo='baz') + host2 = self.create_host('host2', 'server', '192.168.1.2', + region=region) + url = self.url + '/v1/hosts' + + resp = self.get(url, vars='foo:baz') + hosts = resp.json()['hosts'] + self.assertEqual(1, len(hosts)) + self.assertEqual(host1['id'], hosts[0]['id']) + + resp = self.get(url, vars='foo:bar') + hosts = resp.json()['hosts'] + self.assertEqual(1, len(hosts)) + self.assertEqual(host2['id'], hosts[0]['id']) + + def 
test_host_get_all_vars_filter_resolved_host_child_override(self): + host1 = self.create_host('host1', 'server', '192.168.1.1', + baz='zoo') + host2 = self.create_host('host2', 'server', '192.168.1.2', + parent_id=host1['id'], baz='boo') + url = self.url + '/v1/hosts' + + resp = self.get(url, vars='baz:zoo') + hosts = resp.json()['hosts'] + self.assertEqual(1, len(hosts)) + self.assertEqual(host1['id'], hosts[0]['id']) + + resp = self.get(url, vars='baz:boo') + hosts = resp.json()['hosts'] + self.assertEqual(1, len(hosts)) + self.assertEqual(host2['id'], hosts[0]['id']) + + def test_host_get_all_vars_filter_unresolved(self): + host1 = self.create_host('host1', 'server', '192.168.1.1', + foo='bar', baz='zoo') + self.create_host('host2', 'server', '192.168.1.2', foo='bar') + + # NOTE(thomasem): Unfortunately, we use resolved-values instead of + # resolved_values, so we can't pass this in as kwargs to self.get(...), + # see https://bugs.launchpad.net/craton/+bug/1672880. + url = self.url + '/v1/hosts?resolved-values=false&vars=foo:bar,baz:zoo' + + resp = self.get(url) + hosts = resp.json()['hosts'] + self.assertEqual(1, len(hosts)) + self.assertEqual(host1['id'], hosts[0]['id']) + def test_host_delete(self): host = self.create_host('host1', 'server', '192.168.1.1') url = self.url + '/v1/hosts/{}'.format(host['id']) diff --git a/craton/tests/unit/db/test_devices.py b/craton/tests/unit/db/test_devices.py index bcb8752..bcd9b0a 100644 --- a/craton/tests/unit/db/test_devices.py +++ b/craton/tests/unit/db/test_devices.py @@ -538,7 +538,7 @@ class HostsDBTestCase(BaseDevicesDBTestCase): self.context, "hosts", host_id, variables ) filters = { - "region_id": "region_1", + "region_id": 1, "vars": "key1:value5", } res, _ = dbapi.hosts_get_all(self.context, filters, @@ -672,3 +672,33 @@ class HostsDBTestCase(BaseDevicesDBTestCase): 'parent_id': grandchild.id, } ) + + def test_hosts_get_all_with_resolved_var_filters(self): + project_id = self.make_project('project_1', foo='P1', 
zoo='P2') + cloud_id = self.make_cloud(project_id, 'cloud_1') + region_id = self.make_region( + project_id, cloud_id, 'region_1', foo='R1') + switch_id = self.make_network_device( + project_id, cloud_id, region_id, + 'switch1.example.com', IPAddress('10.1.2.101'), 'switch', + zoo='S1', bar='S2') + self.make_host( + project_id, cloud_id, region_id, + 'www.example.xyz', IPAddress(u'10.1.2.101'), 'server', + parent_id=switch_id, + key1="value1", key2="value2") + self.make_host( + project_id, cloud_id, region_id, + 'www2.example.xyz', IPAddress(u'10.1.2.102'), 'server', + parent_id=switch_id, + key1="value-will-not-match", key2="value2") + + filters = { + "region_id": 1, + "vars": "key1:value1,zoo:S1,foo:R1", + "resolved-values": True, + } + res, _ = dbapi.hosts_get_all( + self.context, filters, default_pagination) + self.assertEqual(len(res), 1) + self.assertEqual(res[0].name, 'www.example.xyz')