From 74cb81b28053ebe41e1e0031cfcb30f3039f0857 Mon Sep 17 00:00:00 2001 From: Zane Bitter Date: Mon, 19 Feb 2018 17:00:22 -0500 Subject: [PATCH] Make changes to groupBy() backward-compatible The behaviour of the groupBy() function was fixed in 1.1.2 by 3fb91784018de335440b01b3b069fe45dc53e025 to pass only the list of values to be aggregated to the aggregator function, instead of passing both the key and the list of values (and expecting both the key and the aggregated value to be returned). This fix was incompatible with existing expressions that used the aggregator argument to groupBy(). In the event of an error, fall back trying the previous syntax and see if something plausible gets returned. If not, re-raise the original error. This should mean that pre-existing expressions will continue to work, while most outright bogus expressions should fail in a manner consistent with the correct definition of the function. Change-Id: Ic6c54be4ed99003fe56cf1a5329f3f1d84fd43c8 Closes-Bug: #1750032 --- yaql/__init__.py | 5 +- yaql/standard_library/queries.py | 157 ++++++++++++++++++++++--------- yaql/tests/test_queries.py | 24 +++++ 3 files changed, 138 insertions(+), 48 deletions(-) diff --git a/yaql/__init__.py b/yaql/__init__.py index 5284ebf..eddd54f 100644 --- a/yaql/__init__.py +++ b/yaql/__init__.py @@ -86,7 +86,8 @@ def create_context(data=utils.NO_VALUE, context=None, system=True, math=True, collections=True, queries=True, regex=True, branching=True, no_sets=False, finalizer=None, delegates=False, - convention=None, datetime=True, yaqlized=True): + convention=None, datetime=True, yaqlized=True, + group_by_agg_fallback=True): context = _setup_context(data, context, finalizer, convention) if system: @@ -104,7 +105,7 @@ def create_context(data=utils.NO_VALUE, context=None, system=True, if collections: std_collections.register(context, no_sets) if queries: - std_queries.register(context) + std_queries.register(context, group_by_agg_fallback) if regex: std_regex.register(context) if branching: diff --git a/yaql/standard_library/queries.py b/yaql/standard_library/queries.py index 49bd10d..4df2966 100644 --- a/yaql/standard_library/queries.py +++ b/yaql/standard_library/queries.py @@ -15,10 +15,17 @@ Queries module. """ +# Get python standard library collections module instead of +# yaql.standard_library.collections +from __future__ import absolute_import + +import collections import itertools +import sys import six +from yaql.language import exceptions from yaql.language import specs from yaql.language import utils from yaql.language import yaqltypes @@ -844,53 +851,111 @@ def then_by_descending(collection, selector, context): return collection -@specs.parameter('collection', yaqltypes.Iterable()) -@specs.parameter('key_selector', yaqltypes.Lambda()) -@specs.parameter('value_selector', yaqltypes.Lambda()) -@specs.parameter('aggregator', yaqltypes.Lambda()) -@specs.method -def group_by(engine, collection, key_selector, value_selector=None, - aggregator=None): - """:yaql:groupBy +class GroupAggregator(object): + """A function to aggregate the members of a group found by group_by(). - Returns a collection grouped by keySelector with applied valueSelector as - values. Returns a list of pairs where the first value is a result value - of keySelector and the second is a list of values which have common - keySelector return value. + The user-specified function is provided at creation. It is assumed to + accept the group value list as an argument and return an aggregated value. - :signature: collection.groupBy(keySelector, valueSelector => null, - aggregator => null) - :receiverArg collection: input collection - :argType collection: iterable - :arg keySelector: function to be applied to every collection element. - Values are grouped by return value of this function - :argType keySelector: lambda - :arg valueSelector: function to be applied to every collection element to - put it under appropriate group. null by default, which means return - element itself - :argType valueSelector: lambda - :arg aggregator: function to aggregate value within each group. null by - default, which means no function to be evaluated on groups - :argType aggregator: lambda - :returnType: list - - .. code:: - - yaql> [["a", 1], ["b", 2], ["c", 1], ["d", 2]].groupBy($[1], $[0]) - [[1, ["a", "c"]], [2, ["b", "d"]]] - yaql> [["a", 1], ["b", 2], ["c", 1]].groupBy($[1], $[0], $.sum()) - [[1, "ac"], [2, "b"]] + However, on error we will (optionally) fall back to the pre-1.1.1 behaviour + of assuming that the function expects a tuple containing both the key and + the value list, and similarly returns a tuple of the key and value. This + can still give the wrong results if the first group(s) to be aggregated + have value lists of length exactly 2, but for the most part is backwards + compatible to 1.1.1. """ - groups = {} - if aggregator is None: - new_aggregator = lambda x: x - else: - new_aggregator = lambda x: (x[0], aggregator(x[1])) - for t in collection: - value = t if value_selector is None else value_selector(t) - groups.setdefault(key_selector(t), []).append(value) - utils.limit_memory_usage(engine, (1, groups)) - return select(six.iteritems(groups), new_aggregator) + + def __init__(self, aggregator_func=None, allow_fallback=True): + self.aggregator = aggregator_func + self.allow_fallback = allow_fallback + self._failure_info = None + + def __call__(self, group_item): + if self.aggregator is None: + return group_item + + if self._failure_info is None: + key, value_list = group_item + try: + result = self.aggregator(value_list) + except (exceptions.NoMatchingMethodException, + exceptions.NoMatchingFunctionException, + IndexError): + self._failure_info = sys.exc_info() + else: + if not (len(value_list) == 2 and + isinstance(result, collections.Sequence) and + not isinstance(result, six.string_types) and + len(result) == 2 and + result[0] == value_list[0]): + # We are not dealing with (correct) version 1.1.1 syntax, + # so don't bother trying to fall back if there's an error + # with a later group. + self.allow_fallback = False + + return key, result + + if self.allow_fallback: + # Fall back to assuming version 1.1.1 syntax. + try: + result = self.aggregator(group_item) + if len(result) == 2: + return result + except Exception: + pass + + # If we are unable to successfully fall back, re-raise the first + # exception encountered to help the user debug in the new style. + six.reraise(*self._failure_info) + + +def group_by_function(allow_aggregator_fallback): + @specs.parameter('collection', yaqltypes.Iterable()) + @specs.parameter('key_selector', yaqltypes.Lambda()) + @specs.parameter('value_selector', yaqltypes.Lambda()) + @specs.parameter('aggregator', yaqltypes.Lambda()) + @specs.method + def group_by(engine, collection, key_selector, value_selector=None, + aggregator=None): + """:yaql:groupBy + + Returns a collection grouped by keySelector with applied valueSelector + as values. Returns a list of pairs where the first value is a result + value of keySelector and the second is a list of values which have + common keySelector return value. + + :signature: collection.groupBy(keySelector, valueSelector => null, + aggregator => null) + :receiverArg collection: input collection + :argType collection: iterable + :arg keySelector: function to be applied to every collection element. + Values are grouped by return value of this function + :argType keySelector: lambda + :arg valueSelector: function to be applied to every collection element + to put it under appropriate group. null by default, which means + return element itself + :argType valueSelector: lambda + :arg aggregator: function to aggregate value within each group. null by + default, which means no function to be evaluated on groups + :argType aggregator: lambda + :returnType: list + + .. code:: + + yaql> [["a", 1], ["b", 2], ["c", 1], ["d", 2]].groupBy($[1], $[0]) + [[1, ["a", "c"]], [2, ["b", "d"]]] + yaql> [["a", 1], ["b", 2], ["c", 1]].groupBy($[1], $[0], $.sum()) + [[1, "ac"], [2, "b"]] + """ + groups = {} + new_aggregator = GroupAggregator(aggregator, allow_aggregator_fallback) + for t in collection: + value = t if value_selector is None else value_selector(t) + groups.setdefault(key_selector(t), []).append(value) + utils.limit_memory_usage(engine, (1, groups)) + return select(six.iteritems(groups), new_aggregator) + + return group_by @specs.method @@ -1680,7 +1745,7 @@ def default_if_empty(engine, collection, default): return default -def register(context): +def register(context, allow_group_by_agg_fallback=True): context.register_function(where) context.register_function(where, name='filter') context.register_function(select) @@ -1711,7 +1776,7 @@ def register(context): context.register_function(order_by_descending) context.register_function(then_by) context.register_function(then_by_descending) - context.register_function(group_by) + context.register_function(group_by_function(allow_group_by_agg_fallback)) context.register_function(join) context.register_function(zip_) context.register_function(zip_longest) diff --git a/yaql/tests/test_queries.py b/yaql/tests/test_queries.py index da65ec3..bc2cbab 100644 --- a/yaql/tests/test_queries.py +++ b/yaql/tests/test_queries.py @@ -226,6 +226,30 @@ class TestQueries(yaql.tests.TestCase): 'groupBy($[1], aggregator => $.sum())', data=data)) + def test_group_by_old_syntax(self): + # Test the syntax used in 1.1.1 and earlier, where the aggregator + # function was passed the key as well as the value list, and returned + # the key along with the aggregated value. This ensures backward + # compatibility with existing expressions. + data = {'a': 1, 'b': 2, 'c': 1, 'd': 3, 'e': 2} + + self.assertItemsEqual( + [[1, 'ac'], [2, 'be'], [3, 'd']], + self.eval('$.items().orderBy($[0]).' + 'groupBy($[1], $[0], [$[0], $[1].sum()])', data=data)) + + self.assertItemsEqual( + [[1, ['a', 1, 'c', 1]], [2, ['b', 2, 'e', 2]], [3, ['d', 3]]], + self.eval('$.items().orderBy($[0]).' + 'groupBy($[1],, [$[0], $[1].sum()])', + data=data)) + + self.assertItemsEqual( + [[1, ['a', 1, 'c', 1]], [2, ['b', 2, 'e', 2]], [3, ['d', 3]]], + self.eval('$.items().orderBy($[0]).' + 'groupBy($[1], aggregator => [$[0], $[1].sum()])', + data=data)) + def test_join(self): self.assertEqual( [[2, 1], [3, 1], [3, 2], [4, 1], [4, 2], [4, 3]],