Refactor validation of arguments to 'repeat' intrinsic function

Separate validation of the structure of the function call itself from
validation of the arguments passed. The former remains in the
initialisation method, while the latter moves to the validate() method.
This is consistent with how other intrinsic functions are instantiated,
and should hopefully lead to less confusion of the kind that resulted in
bug 1553306.

Change-Id: Ie20b79c33bded82db3befd3ca25fa76ebed3d417
Related-Bug: 1553306
This commit is contained in:
Zane Bitter 2016-03-04 14:20:01 -05:00
parent c5b214e520
commit 5c14291eef
2 changed files with 33 additions and 25 deletions

View File

@ -522,34 +522,39 @@ class Repeat(function.Function):
def __init__(self, stack, fn_name, args):
super(Repeat, self).__init__(stack, fn_name, args)
self._for_each, self._template = self._parse_args()
def _parse_args(self):
if not isinstance(self.args, collections.Mapping):
raise TypeError(_('Arguments to "%s" must be a map') %
self.fn_name)
# We don't check for invalid keys appearing here, which is wrong but
# it's probably too late to change
try:
for_each = self.args['for_each']
template = self.args['template']
except (KeyError, TypeError):
self._for_each = self.args['for_each']
self._template = self.args['template']
except KeyError:
example = ('''repeat:
template: This is %var%
for_each:
%var%: ['a', 'b', 'c']''')
raise KeyError(_('"repeat" syntax should be %s') %
example)
raise KeyError(_('"repeat" syntax should be %s') % example)
if not isinstance(for_each, function.Function):
if not isinstance(for_each, collections.Mapping):
def validate(self):
super(Repeat, self).validate()
if not isinstance(self._for_each, function.Function):
if not isinstance(self._for_each, collections.Mapping):
raise TypeError(_('The "for_each" argument to "%s" must '
'contain a map') % self.fn_name)
for v in six.itervalues(for_each):
if not isinstance(v, (list, function.Function)):
raise TypeError(_('The values of the "for_each" argument '
'to "%s" must be lists') % self.fn_name)
return for_each, template
if not all(self._valid_list(v) for v in self._for_each.values()):
raise TypeError(_('The values of the "for_each" argument '
'to "%s" must be lists') % self.fn_name)
@staticmethod
def _valid_list(arg):
return (isinstance(arg, (collections.Sequence,
function.Function)) and
not isinstance(arg, six.string_types))
def _do_replacement(self, keys, values, template):
if isinstance(template, six.string_types):
@ -566,16 +571,15 @@ class Repeat(function.Function):
def result(self):
for_each = function.resolve(self._for_each)
keys = list(six.iterkeys(for_each))
lists = [for_each[key] for key in keys]
if not all(isinstance(l, list) for l in lists):
if not all(self._valid_list(l) for l in for_each.values()):
raise TypeError(_('The values of the "for_each" argument to '
'"%s" must be lists') % self.fn_name)
template = function.resolve(self._template)
return [self._do_replacement(keys, items, template)
for items in itertools.product(*lists)]
keys, lists = six.moves.zip(*for_each.items())
return [self._do_replacement(keys, replacements, template)
for replacements in itertools.product(*lists)]
class Digest(function.Function):

View File

@ -917,11 +917,6 @@ class HOTemplateTest(common.HeatTestCase):
'foreach': {'%var%': ['a', 'b', 'c']}}}
self.assertRaises(KeyError, self.resolve, snippet, tmpl)
# for_each is not a map
snippet = {'repeat': {'template': 'this is %var%',
'for_each': '%var%'}}
self.assertRaises(TypeError, self.resolve, snippet, tmpl)
# value given to for_each entry is not a list
snippet = {'repeat': {'template': 'this is %var%',
'for_each': {'%var%': 'a'}}}
@ -932,6 +927,15 @@ class HOTemplateTest(common.HeatTestCase):
'for_each': {'%var%': ['a', 'b', 'c']}}}
self.assertRaises(KeyError, self.resolve, snippet, tmpl)
def test_repeat_bad_arg_type(self):
tmpl = template.Template(hot_kilo_tpl_empty)
# for_each is not a map
snippet = {'repeat': {'template': 'this is %var%',
'for_each': '%var%'}}
repeat = tmpl.parse(None, snippet)
self.assertRaises(TypeError, function.validate, repeat)
def test_digest(self):
snippet = {'digest': ['md5', 'foobar']}
snippet_resolved = '3858f62230ac3c915f300c664312c63f'