Make changes to groupBy() backward-compatible

The behaviour of the groupBy() function was fixed in 1.1.2 by
3fb9178401 to pass only the list of values
to be aggregated to the aggregator function, instead of passing both the
key and the list of values (and expecting both the key and the
aggregated value to be returned). This fix was incompatible with
existing expressions that used the aggregator argument to groupBy().

In the event of an error, fall back trying the previous syntax and see
if something plausible gets returned. If not, re-raise the original error.
This should mean that pre-existing expressions will continue to work,
while most outright bogus expressions should fail in a manner consistent
with the correct definition of the function.

Change-Id: Ic6c54be4ed99003fe56cf1a5329f3f1d84fd43c8
Closes-Bug: #1750032
This commit is contained in:
Zane Bitter 2018-02-19 17:00:22 -05:00
parent dc21a823bd
commit 74cb81b280
3 changed files with 138 additions and 48 deletions

View File

@ -86,7 +86,8 @@ def create_context(data=utils.NO_VALUE, context=None, system=True,
math=True, collections=True, queries=True,
regex=True, branching=True,
no_sets=False, finalizer=None, delegates=False,
convention=None, datetime=True, yaqlized=True):
convention=None, datetime=True, yaqlized=True,
group_by_agg_fallback=True):
context = _setup_context(data, context, finalizer, convention)
if system:
@ -104,7 +105,7 @@ def create_context(data=utils.NO_VALUE, context=None, system=True,
if collections:
std_collections.register(context, no_sets)
if queries:
std_queries.register(context)
std_queries.register(context, group_by_agg_fallback)
if regex:
std_regex.register(context)
if branching:

View File

@ -15,10 +15,17 @@
Queries module.
"""
# Get python standard library collections module instead of
# yaql.standard_library.collections
from __future__ import absolute_import
import collections
import itertools
import sys
import six
from yaql.language import exceptions
from yaql.language import specs
from yaql.language import utils
from yaql.language import yaqltypes
@ -844,53 +851,111 @@ def then_by_descending(collection, selector, context):
return collection
@specs.parameter('collection', yaqltypes.Iterable())
@specs.parameter('key_selector', yaqltypes.Lambda())
@specs.parameter('value_selector', yaqltypes.Lambda())
@specs.parameter('aggregator', yaqltypes.Lambda())
@specs.method
def group_by(engine, collection, key_selector, value_selector=None,
aggregator=None):
""":yaql:groupBy
class GroupAggregator(object):
"""A function to aggregate the members of a group found by group_by().
Returns a collection grouped by keySelector with applied valueSelector as
values. Returns a list of pairs where the first value is a result value
of keySelector and the second is a list of values which have common
keySelector return value.
The user-specified function is provided at creation. It is assumed to
accept the group value list as an argument and return an aggregated value.
:signature: collection.groupBy(keySelector, valueSelector => null,
aggregator => null)
:receiverArg collection: input collection
:argType collection: iterable
:arg keySelector: function to be applied to every collection element.
Values are grouped by return value of this function
:argType keySelector: lambda
:arg valueSelector: function to be applied to every collection element to
put it under appropriate group. null by default, which means return
element itself
:argType valueSelector: lambda
:arg aggregator: function to aggregate value within each group. null by
default, which means no function to be evaluated on groups
:argType aggregator: lambda
:returnType: list
.. code::
yaql> [["a", 1], ["b", 2], ["c", 1], ["d", 2]].groupBy($[1], $[0])
[[1, ["a", "c"]], [2, ["b", "d"]]]
yaql> [["a", 1], ["b", 2], ["c", 1]].groupBy($[1], $[0], $.sum())
[[1, "ac"], [2, "b"]]
However, on error we will (optionally) fall back to the pre-1.1.1 behaviour
of assuming that the function expects a tuple containing both the key and
the value list, and similarly returns a tuple of the key and value. This
can still give the wrong results if the first group(s) to be aggregated
have value lists of length exactly 2, but for the most part is backwards
compatible to 1.1.1.
"""
groups = {}
if aggregator is None:
new_aggregator = lambda x: x
else:
new_aggregator = lambda x: (x[0], aggregator(x[1]))
for t in collection:
value = t if value_selector is None else value_selector(t)
groups.setdefault(key_selector(t), []).append(value)
utils.limit_memory_usage(engine, (1, groups))
return select(six.iteritems(groups), new_aggregator)
def __init__(self, aggregator_func=None, allow_fallback=True):
self.aggregator = aggregator_func
self.allow_fallback = allow_fallback
self._failure_info = None
def __call__(self, group_item):
if self.aggregator is None:
return group_item
if self._failure_info is None:
key, value_list = group_item
try:
result = self.aggregator(value_list)
except (exceptions.NoMatchingMethodException,
exceptions.NoMatchingFunctionException,
IndexError):
self._failure_info = sys.exc_info()
else:
if not (len(value_list) == 2 and
isinstance(result, collections.Sequence) and
not isinstance(result, six.string_types) and
len(result) == 2 and
result[0] == value_list[0]):
# We are not dealing with (correct) version 1.1.1 syntax,
# so don't bother trying to fall back if there's an error
# with a later group.
self.allow_fallback = False
return key, result
if self.allow_fallback:
# Fall back to assuming version 1.1.1 syntax.
try:
result = self.aggregator(group_item)
if len(result) == 2:
return result
except Exception:
pass
# If we are unable to successfully fall back, re-raise the first
# exception encountered to help the user debug in the new style.
six.reraise(*self._failure_info)
def group_by_function(allow_aggregator_fallback):
@specs.parameter('collection', yaqltypes.Iterable())
@specs.parameter('key_selector', yaqltypes.Lambda())
@specs.parameter('value_selector', yaqltypes.Lambda())
@specs.parameter('aggregator', yaqltypes.Lambda())
@specs.method
def group_by(engine, collection, key_selector, value_selector=None,
aggregator=None):
""":yaql:groupBy
Returns a collection grouped by keySelector with applied valueSelector
as values. Returns a list of pairs where the first value is a result
value of keySelector and the second is a list of values which have
common keySelector return value.
:signature: collection.groupBy(keySelector, valueSelector => null,
aggregator => null)
:receiverArg collection: input collection
:argType collection: iterable
:arg keySelector: function to be applied to every collection element.
Values are grouped by return value of this function
:argType keySelector: lambda
:arg valueSelector: function to be applied to every collection element
to put it under appropriate group. null by default, which means
return element itself
:argType valueSelector: lambda
:arg aggregator: function to aggregate value within each group. null by
default, which means no function to be evaluated on groups
:argType aggregator: lambda
:returnType: list
.. code::
yaql> [["a", 1], ["b", 2], ["c", 1], ["d", 2]].groupBy($[1], $[0])
[[1, ["a", "c"]], [2, ["b", "d"]]]
yaql> [["a", 1], ["b", 2], ["c", 1]].groupBy($[1], $[0], $.sum())
[[1, "ac"], [2, "b"]]
"""
groups = {}
new_aggregator = GroupAggregator(aggregator, allow_aggregator_fallback)
for t in collection:
value = t if value_selector is None else value_selector(t)
groups.setdefault(key_selector(t), []).append(value)
utils.limit_memory_usage(engine, (1, groups))
return select(six.iteritems(groups), new_aggregator)
return group_by
@specs.method
@ -1680,7 +1745,7 @@ def default_if_empty(engine, collection, default):
return default
def register(context):
def register(context, allow_group_by_agg_fallback=True):
context.register_function(where)
context.register_function(where, name='filter')
context.register_function(select)
@ -1711,7 +1776,7 @@ def register(context):
context.register_function(order_by_descending)
context.register_function(then_by)
context.register_function(then_by_descending)
context.register_function(group_by)
context.register_function(group_by_function(allow_group_by_agg_fallback))
context.register_function(join)
context.register_function(zip_)
context.register_function(zip_longest)

View File

@ -226,6 +226,30 @@ class TestQueries(yaql.tests.TestCase):
'groupBy($[1], aggregator => $.sum())',
data=data))
def test_group_by_old_syntax(self):
# Test the syntax used in 1.1.1 and earlier, where the aggregator
# function was passed the key as well as the value list, and returned
# the key along with the aggregated value. This ensures backward
# compatibility with existing expressions.
data = {'a': 1, 'b': 2, 'c': 1, 'd': 3, 'e': 2}
self.assertItemsEqual(
[[1, 'ac'], [2, 'be'], [3, 'd']],
self.eval('$.items().orderBy($[0]).'
'groupBy($[1], $[0], [$[0], $[1].sum()])', data=data))
self.assertItemsEqual(
[[1, ['a', 1, 'c', 1]], [2, ['b', 2, 'e', 2]], [3, ['d', 3]]],
self.eval('$.items().orderBy($[0]).'
'groupBy($[1],, [$[0], $[1].sum()])',
data=data))
self.assertItemsEqual(
[[1, ['a', 1, 'c', 1]], [2, ['b', 2, 'e', 2]], [3, ['d', 3]]],
self.eval('$.items().orderBy($[0]).'
'groupBy($[1], aggregator => [$[0], $[1].sum()])',
data=data))
def test_join(self):
self.assertEqual(
[[2, 1], [3, 1], [3, 2], [4, 1], [4, 2], [4, 3]],