diff --git a/ironic/api/args.py b/ironic/api/args.py new file mode 100644 index 0000000000..4c721f7b19 --- /dev/null +++ b/ironic/api/args.py @@ -0,0 +1,387 @@ +# Copyright 2011-2019 the WSME authors and contributors +# (See https://opendev.org/x/wsme/) +# +# This module is part of WSME and is also released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import cgi +import datetime +import decimal +import json +import logging + +from dateutil import parser as dateparser + +from ironic.api import types as atypes +from ironic.common import exception + +LOG = logging.getLogger(__name__) + + +CONTENT_TYPE = 'application/json' +ACCEPT_CONTENT_TYPES = [ + CONTENT_TYPE, + 'text/javascript', + 'application/javascript' +] +ENUM_TRUE = ('true', 't', 'yes', 'y', 'on', '1') +ENUM_FALSE = ('false', 'f', 'no', 'n', 'off', '0') + + +def fromjson_array(datatype, value): + if not isinstance(value, list): + raise ValueError("Value not a valid list: %s" % value) + return [fromjson(datatype.item_type, item) for item in value] + + +def fromjson_dict(datatype, value): + if not isinstance(value, dict): + raise ValueError("Value not a valid dict: %s" % value) + return dict(( + (fromjson(datatype.key_type, item[0]), + fromjson(datatype.value_type, item[1])) + for item in value.items())) + + +def fromjson_bool(value): + if isinstance(value, (int, bool)): + return bool(value) + if value in ENUM_TRUE: + return True + if value in ENUM_FALSE: + return False + raise ValueError("Value not an unambiguous boolean: %s" % value) + + +def fromjson(datatype, value): + """A generic converter from json base types to python datatype. + + """ + if value is None: + return None + + if isinstance(datatype, atypes.ArrayType): + return fromjson_array(datatype, value) + + if isinstance(datatype, atypes.DictType): + return fromjson_dict(datatype, value) + + if datatype is bytes: + if isinstance(value, (str, int, float)): + return str(value).encode('utf8') + return value + + if datatype is str: + if isinstance(value, bytes): + return value.decode('utf-8') + return value + + if datatype in (int, float): + return datatype(value) + + if datatype is bool: + return fromjson_bool(value) + + if datatype is decimal.Decimal: + return decimal.Decimal(value) + + if datatype is datetime.datetime: + return dateparser.parse(value) + + if atypes.iscomplex(datatype): + return fromjson_complex(datatype, value) + + if atypes.isusertype(datatype): + return datatype.frombasetype(fromjson(datatype.basetype, value)) + + return value + + +def fromjson_complex(datatype, value): + obj = datatype() + attributes = atypes.list_attributes(datatype) + + # Here we check that all the attributes in the value are also defined + # in our type definition, otherwise we raise an Error. + v_keys = set(value.keys()) + a_keys = set(adef.name for adef in attributes) + if not v_keys <= a_keys: + raise exception.UnknownAttribute(None, v_keys - a_keys) + + for attrdef in attributes: + if attrdef.name in value: + try: + val_fromjson = fromjson(attrdef.datatype, + value[attrdef.name]) + except exception.UnknownAttribute as e: + e.add_fieldname(attrdef.name) + raise + if getattr(attrdef, 'readonly', False): + raise exception.InvalidInput(attrdef.name, val_fromjson, + "Cannot set read only field.") + setattr(obj, attrdef.key, val_fromjson) + elif attrdef.mandatory: + raise exception.InvalidInput(attrdef.name, None, + "Mandatory field missing.") + + return atypes.validate_value(datatype, obj) + + +def parse(s, datatypes, bodyarg, encoding='utf8'): + jload = json.load + if not hasattr(s, 'read'): + if isinstance(s, bytes): + s = s.decode(encoding) + jload = json.loads + try: + jdata = jload(s) + except ValueError: + raise exception.ClientSideError("Request is not in valid JSON format") + if bodyarg: + argname = list(datatypes.keys())[0] + try: + kw = {argname: fromjson(datatypes[argname], jdata)} + except ValueError as e: + raise exception.InvalidInput(argname, jdata, e.args[0]) + except exception.UnknownAttribute as e: + # We only know the fieldname at this level, not in the + # called function. We fill in this information here. + e.add_fieldname(argname) + raise + else: + kw = {} + extra_args = [] + if not isinstance(jdata, dict): + raise exception.ClientSideError("Request must be a JSON dict") + for key in jdata: + if key not in datatypes: + extra_args.append(key) + else: + try: + kw[key] = fromjson(datatypes[key], jdata[key]) + except ValueError as e: + raise exception.InvalidInput(key, jdata[key], e.args[0]) + except exception.UnknownAttribute as e: + # We only know the fieldname at this level, not in the + # called function. We fill in this information here. + e.add_fieldname(key) + raise + if extra_args: + raise exception.UnknownArgument(', '.join(extra_args)) + return kw + + +def from_param(datatype, value): + if datatype is datetime.datetime: + return dateparser.parse(value) if value else None + + if datatype is atypes.File: + if isinstance(value, cgi.FieldStorage): + return atypes.File(fieldstorage=value) + return atypes.File(content=value) + + if isinstance(datatype, atypes.UserType): + return datatype.frombasetype( + from_param(datatype.basetype, value)) + + if isinstance(datatype, atypes.ArrayType): + if value is None: + return value + return [ + from_param(datatype.item_type, item) + for item in value + ] + + return datatype(value) if value is not None else None + + +def from_params(datatype, params, path, hit_paths): + if isinstance(datatype, atypes.ArrayType): + return array_from_params(datatype, params, path, hit_paths) + + if isinstance(datatype, atypes.UserType): + return usertype_from_params(datatype, params, path, hit_paths) + + if path in params: + assert not isinstance(datatype, atypes.DictType), \ + 'DictType unsupported' + assert not atypes.iscomplex(datatype) or datatype is atypes.File, \ + 'complex type unsupported' + hit_paths.add(path) + return from_param(datatype, params[path]) + return atypes.Unset + + +def array_from_params(datatype, params, path, hit_paths): + if hasattr(params, 'getall'): + # webob multidict + def getall(params, path): + return params.getall(path) + elif hasattr(params, 'getlist'): + # werkzeug multidict + def getall(params, path): # noqa + return params.getlist(path) + if path in params: + hit_paths.add(path) + return [ + from_param(datatype.item_type, value) + for value in getall(params, path)] + + return atypes.Unset + + +def usertype_from_params(datatype, params, path, hit_paths): + if path in params: + hit_paths.add(path) + value = from_param(datatype.basetype, params[path]) + if value is not atypes.Unset: + return datatype.frombasetype(value) + return atypes.Unset + + +def args_from_args(funcdef, args, kwargs): + newargs = [] + for argdef, arg in zip(funcdef.arguments[:len(args)], args): + try: + newargs.append(from_param(argdef.datatype, arg)) + except Exception as e: + if isinstance(argdef.datatype, atypes.UserType): + datatype_name = argdef.datatype.name + elif isinstance(argdef.datatype, type): + datatype_name = argdef.datatype.__name__ + else: + datatype_name = argdef.datatype.__class__.__name__ + raise exception.InvalidInput( + argdef.name, + arg, + "unable to convert to %(datatype)s. Error: %(error)s" % { + 'datatype': datatype_name, 'error': e}) + newkwargs = {} + for argname, value in kwargs.items(): + newkwargs[argname] = from_param( + funcdef.get_arg(argname).datatype, value + ) + return newargs, newkwargs + + +def args_from_params(funcdef, params): + kw = {} + hit_paths = set() + for argdef in funcdef.arguments: + value = from_params( + argdef.datatype, params, argdef.name, hit_paths) + if value is not atypes.Unset: + kw[argdef.name] = value + paths = set(params.keys()) + unknown_paths = paths - hit_paths + if '__body__' in unknown_paths: + unknown_paths.remove('__body__') + if not funcdef.ignore_extra_args and unknown_paths: + raise exception.UnknownArgument(', '.join(unknown_paths)) + return [], kw + + +def args_from_body(funcdef, body, mimetype): + if funcdef.body_type is not None: + datatypes = {funcdef.arguments[-1].name: funcdef.body_type} + else: + datatypes = dict(((a.name, a.datatype) for a in funcdef.arguments)) + + if not body: + return (), {} + + if mimetype == "application/x-www-form-urlencoded": + # the parameters should have been parsed in params + return (), {} + elif mimetype not in ACCEPT_CONTENT_TYPES: + raise exception.ClientSideError("Unknown mimetype: %s" % mimetype, + status_code=415) + + try: + kw = parse( + body, datatypes, bodyarg=funcdef.body_type is not None + ) + except exception.UnknownArgument: + if not funcdef.ignore_extra_args: + raise + kw = {} + + return (), kw + + +def combine_args(funcdef, akw, allow_override=False): + newargs, newkwargs = [], {} + for args, kwargs in akw: + for i, arg in enumerate(args): + n = funcdef.arguments[i].name + if not allow_override and n in newkwargs: + raise exception.ClientSideError( + "Parameter %s was given several times" % n) + newkwargs[n] = arg + for name, value in kwargs.items(): + n = str(name) + if not allow_override and n in newkwargs: + raise exception.ClientSideError( + "Parameter %s was given several times" % n) + newkwargs[n] = value + return newargs, newkwargs + + +def get_args(funcdef, args, kwargs, params, body, mimetype): + """Combine arguments from multiple sources + + Combine arguments from : + * the host framework args and kwargs + * the request params + * the request body + + Note that the host framework args and kwargs can be overridden + by arguments from params of body + + """ + # get the body from params if not given directly + if not body and '__body__' in params: + body = params['__body__'] + + # extract args from the host args and kwargs + from_args = args_from_args(funcdef, args, kwargs) + + # extract args from the request parameters + from_params = args_from_params(funcdef, params) + + # extract args from the request body + from_body = args_from_body(funcdef, body, mimetype) + + # combine params and body arguments + from_params_and_body = combine_args( + funcdef, + (from_params, from_body) + ) + + args, kwargs = combine_args( + funcdef, + (from_args, from_params_and_body), + allow_override=True + ) + check_arguments(funcdef, args, kwargs) + return args, kwargs + + +def check_arguments(funcdef, args, kw): + """Check if some arguments are missing""" + assert len(args) == 0 + for arg in funcdef.arguments: + if arg.mandatory and arg.name not in kw: + raise exception.MissingArgument(arg.name) diff --git a/ironic/api/expose.py b/ironic/api/expose.py index 71bfa15005..94bfb3b5a1 100644 --- a/ironic/api/expose.py +++ b/ironic/api/expose.py @@ -25,8 +25,8 @@ import traceback from oslo_config import cfg from oslo_log import log import pecan -import wsme.rest.args +from ironic.api import args as api_args from ironic.api import functions from ironic.api import types as atypes @@ -70,8 +70,8 @@ def expose(*args, **kwargs): return_type = funcdef.return_type try: - args, kwargs = wsme.rest.args.get_args( - funcdef, args, kwargs, pecan.request.params, None, + args, kwargs = api_args.get_args( + funcdef, args, kwargs, pecan.request.params, pecan.request.body, pecan.request.content_type ) result = f(self, *args, **kwargs) diff --git a/ironic/api/types.py b/ironic/api/types.py index 0da12360b3..ae9efa55d4 100644 --- a/ironic/api/types.py +++ b/ironic/api/types.py @@ -22,6 +22,7 @@ from wsme.types import Enum # noqa from wsme.types import File # noqa from wsme.types import IntegerType # noqa from wsme.types import iscomplex # noqa +from wsme.types import isusertype # noqa from wsme.types import list_attributes # noqa from wsme.types import registry # noqa from wsme.types import StringType # noqa @@ -29,6 +30,7 @@ from wsme.types import text # noqa from wsme.types import Unset # noqa from wsme.types import UnsetType # noqa from wsme.types import UserType # noqa +from wsme.types import validate_value # noqa from wsme.types import wsattr # noqa from wsme.types import wsproperty # noqa diff --git a/ironic/common/exception.py b/ironic/common/exception.py index a314e75e01..7ce8ba8e7a 100644 --- a/ironic/common/exception.py +++ b/ironic/common/exception.py @@ -725,3 +725,75 @@ class NodeIsRetired(Invalid): class NoFreeIPMITerminalPorts(TemporaryFailure): _msg_fmt = _("Unable to allocate a free port on host %(host)s for IPMI " "terminal, not enough free ports.") + + +class InvalidInput(ClientSideError): + def __init__(self, fieldname, value, msg=''): + self.fieldname = fieldname + self.value = value + super(InvalidInput, self).__init__(msg) + + @property + def faultstring(self): + return _( + "Invalid input for field/attribute %(fieldname)s. " + "Value: '%(value)s'. %(msg)s" + ) % { + 'fieldname': self.fieldname, + 'value': self.value, + 'msg': self.msg + } + + +class UnknownArgument(ClientSideError): + def __init__(self, argname, msg=''): + self.argname = argname + super(UnknownArgument, self).__init__(msg) + + @property + def faultstring(self): + return _('Unknown argument: "%(argname)s"%(msg)s') % { + 'argname': self.argname, + 'msg': self.msg and ": " + self.msg or "" + } + + +class MissingArgument(ClientSideError): + def __init__(self, argname, msg=''): + self.argname = argname + super(MissingArgument, self).__init__(msg) + + @property + def faultstring(self): + return _('Missing argument: "%(argname)s"%(msg)s') % { + 'argname': self.argname, + 'msg': self.msg and ": " + self.msg or "" + } + + +class UnknownAttribute(ClientSideError): + def __init__(self, fieldname, attributes, msg=''): + self.fieldname = fieldname + self.attributes = attributes + self.msg = msg + super(UnknownAttribute, self).__init__(self.msg) + + @property + def faultstring(self): + error = _("Unknown attribute for argument %(argn)s: %(attrs)s") + if len(self.attributes) > 1: + error = _("Unknown attributes for argument %(argn)s: %(attrs)s") + str_attrs = ", ".join(self.attributes) + return error % {'argn': self.fieldname, 'attrs': str_attrs} + + def add_fieldname(self, name): + """Add a fieldname to concatenate the full name. + + Add a fieldname so that the whole hierarchy is displayed. Successive + calls to this method will prepend ``name`` to the hierarchy of names. + """ + if self.fieldname is not None: + self.fieldname = "{}.{}".format(name, self.fieldname) + else: + self.fieldname = name + super(UnknownAttribute, self).__init__(self.msg) diff --git a/ironic/common/utils.py b/ironic/common/utils.py index 2d389af593..4141973e13 100644 --- a/ironic/common/utils.py +++ b/ironic/common/utils.py @@ -42,6 +42,15 @@ from ironic.conf import CONF LOG = logging.getLogger(__name__) +DATE_RE = r'(?P-?\d{4,})-(?P\d{2})-(?P\d{2})' +TIME_RE = r'(?P\d{2}):(?P\d{2}):(?P\d{2})' + \ + r'(\.(?P\d+))?' +TZ_RE = r'((?P[+-])(?P\d{2}):(?P\d{2}))' + \ + r'|(?PZ)' + +DATETIME_RE = re.compile( + '%sT%s(%s)?' % (DATE_RE, TIME_RE, TZ_RE)) + warn_deprecated_extra_vif_port_id = False diff --git a/ironic/tests/unit/api/test_args.py b/ironic/tests/unit/api/test_args.py new file mode 100644 index 0000000000..57c4f5feac --- /dev/null +++ b/ironic/tests/unit/api/test_args.py @@ -0,0 +1,512 @@ +# Copyright 2020 Red Hat, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import datetime +import decimal +import io + +from webob import multidict + +from ironic.api import args +from ironic.api.controllers.v1 import types +from ironic.api import functions +from ironic.api import types as atypes +from ironic.common import exception +from ironic.tests import base as test_base + + +class Obj(atypes.Base): + + id = atypes.wsattr(int, mandatory=True) + name = str + readonly_field = atypes.wsattr(str, readonly=True) + default_field = atypes.wsattr(str, default='foo') + unset_me = str + + +class NestedObj(atypes.Base): + o = Obj + + +class TestArgs(test_base.TestCase): + + def test_fromjson_array(self): + atype = atypes.ArrayType(int) + self.assertEqual( + [0, 1, 1234, None], + args.fromjson_array(atype, [0, '1', '1_234', None]) + ) + self.assertRaises(ValueError, args.fromjson_array, + atype, ['one', 'two', 'three']) + self.assertRaises(ValueError, args.fromjson_array, + atype, 'one') + + def test_fromjson_dict(self): + dtype = atypes.DictType(str, int) + self.assertEqual({ + 'zero': 0, + 'one': 1, + 'etc': 1234, + 'none': None + }, args.fromjson_dict(dtype, { + 'zero': 0, + 'one': '1', + 'etc': '1_234', + 'none': None + })) + + self.assertRaises(ValueError, args.fromjson_dict, + dtype, []) + self.assertRaises(ValueError, args.fromjson_dict, + dtype, {'one': 'one'}) + + def test_fromjson_bool(self): + for b in (1, 2, True, 'true', 't', 'yes', 'y', 'on', '1'): + self.assertTrue(args.fromjson_bool(b)) + for b in (0, False, 'false', 'f', 'no', 'n', 'off', '0'): + self.assertFalse(args.fromjson_bool(b)) + for b in ('yup', 'yeet', 'NOPE', 3.14): + self.assertRaises(ValueError, args.fromjson_bool, b) + + def test_fromjson(self): + # parse None + self.assertIsNone(args.fromjson(None, None)) + + # parse array + atype = atypes.ArrayType(int) + self.assertEqual( + [0, 1, 1234, None], + args.fromjson(atype, [0, '1', '1_234', None]) + ) + + # parse dict + dtype = atypes.DictType(str, int) + self.assertEqual({ + 'zero': 0, + 'one': 1, + 'etc': 1234, + 'none': None + }, args.fromjson(dtype, { + 'zero': 0, + 'one': '1', + 'etc': '1_234', + 'none': None + })) + + # parse bytes + self.assertEqual( + b'asdf', + args.fromjson(bytes, b'asdf') + ) + self.assertEqual( + b'asdf', + args.fromjson(bytes, 'asdf') + ) + self.assertEqual( + b'33', + args.fromjson(bytes, 33) + ) + self.assertEqual( + b'3.14', + args.fromjson(bytes, 3.14) + ) + + # parse str + self.assertEqual( + 'asdf', + args.fromjson(str, b'asdf') + ) + self.assertEqual( + 'asdf', + args.fromjson(str, 'asdf') + ) + + # parse int/float + self.assertEqual( + 3, + args.fromjson(int, '3') + ) + self.assertEqual( + 3, + args.fromjson(int, 3) + ) + self.assertEqual( + 3.14, + args.fromjson(float, 3.14) + ) + + # parse bool + self.assertFalse(args.fromjson(bool, 'no')) + self.assertTrue(args.fromjson(bool, 'yes')) + + # parse decimal + self.assertEqual( + decimal.Decimal(3.14), + args.fromjson(decimal.Decimal, 3.14) + ) + + # parse datetime + expected = datetime.datetime(2015, 8, 13, 11, 38, 9, 496475) + self.assertEqual( + expected, + args.fromjson(datetime.datetime, '2015-08-13T11:38:09.496475') + ) + + # parse complex + n = args.fromjson(NestedObj, {'o': { + 'id': 1234, + 'name': 'an object' + }}) + self.assertIsInstance(n.o, Obj) + self.assertEqual(1234, n.o.id) + self.assertEqual('an object', n.o.name) + self.assertEqual('foo', n.o.default_field) + + # parse usertype + self.assertEqual( + ['0', '1', '2', 'three'], + args.fromjson(types.listtype, '0,1, 2, three') + ) + + def test_fromjson_complex(self): + n = args.fromjson_complex(NestedObj, {'o': { + 'id': 1234, + 'name': 'an object' + }}) + self.assertIsInstance(n.o, Obj) + self.assertEqual(1234, n.o.id) + self.assertEqual('an object', n.o.name) + self.assertEqual('foo', n.o.default_field) + + e = self.assertRaises(exception.UnknownAttribute, + args.fromjson_complex, + Obj, {'ooo': {}}) + self.assertEqual({'ooo'}, e.attributes) + + e = self.assertRaises(exception.InvalidInput, args.fromjson_complex, + Obj, + {'name': 'an object'}) + self.assertEqual('id', e.fieldname) + self.assertEqual('Mandatory field missing.', e.msg) + + e = self.assertRaises(exception.InvalidInput, args.fromjson_complex, + Obj, + {'id': 1234, 'readonly_field': 'foo'}) + self.assertEqual('readonly_field', e.fieldname) + self.assertEqual('Cannot set read only field.', e.msg) + + def test_parse(self): + # source as bytes + s = b'{"o": {"id": 1234, "name": "an object"}}' + + # test bodyarg=True + n = args.parse(s, {"o": NestedObj}, True)['o'] + self.assertEqual(1234, n.o.id) + self.assertEqual('an object', n.o.name) + + # source as file + s = io.StringIO('{"o": {"id": 1234, "name": "an object"}}') + + # test bodyarg=False + n = args.parse(s, {"o": Obj}, False)['o'] + self.assertEqual(1234, n.id) + self.assertEqual('an object', n.name) + + # fromjson ValueError + s = '{"o": ["id", "name"]}' + self.assertRaises(exception.InvalidInput, args.parse, + s, {"o": atypes.DictType(str, str)}, False) + s = '["id", "name"]' + self.assertRaises(exception.InvalidInput, args.parse, + s, {"o": atypes.DictType(str, str)}, True) + + # fromjson UnknownAttribute + s = '{"o": {"foo": "bar", "id": 1234, "name": "an object"}}' + self.assertRaises(exception.UnknownAttribute, args.parse, + s, {"o": NestedObj}, True) + self.assertRaises(exception.UnknownAttribute, args.parse, + s, {"o": Obj}, False) + + # invalid json + s = '{Sunn O)))}' + self.assertRaises(exception.ClientSideError, args.parse, + s, {"o": Obj}, False) + + # extra args + s = '{"foo": "bar", "o": {"id": 1234, "name": "an object"}}' + self.assertRaises(exception.UnknownArgument, args.parse, + s, {"o": Obj}, False) + + def test_from_param(self): + # datetime param + expected = datetime.datetime(2015, 8, 13, 11, 38, 9, 496475) + self.assertEqual( + expected, + args.from_param(datetime.datetime, '2015-08-13T11:38:09.496475') + ) + self.assertIsNone(args.from_param(datetime.datetime, None)) + + # file param + self.assertEqual( + b'foo', + args.from_param(atypes.File, b'foo').content + ) + + # usertype param + self.assertEqual( + ['0', '1', '2', 'three'], + args.from_param(types.listtype, '0,1, 2, three') + ) + + # array param + atype = atypes.ArrayType(int) + self.assertEqual( + [0, 1, 1234, None], + args.from_param(atype, [0, '1', '1_234', None]) + ) + self.assertIsNone(args.from_param(atype, None)) + + # string param + self.assertEqual('foo', args.from_param(str, 'foo')) + self.assertIsNone(args.from_param(str, None)) + + # string param with from_params + hit_paths = set() + params = multidict.MultiDict( + foo='bar', + ) + self.assertEqual( + 'bar', + args.from_params(str, params, 'foo', hit_paths) + ) + self.assertEqual({'foo'}, hit_paths) + + def test_array_from_params(self): + hit_paths = set() + datatype = atypes.ArrayType(str) + params = multidict.MultiDict( + foo='bar', + one='two' + ) + self.assertEqual( + ['bar'], + args.from_params(datatype, params, 'foo', hit_paths) + ) + self.assertEqual({'foo'}, hit_paths) + self.assertEqual( + ['two'], + args.array_from_params(datatype, params, 'one', hit_paths) + ) + self.assertEqual({'foo', 'one'}, hit_paths) + + def test_usertype_from_params(self): + hit_paths = set() + datatype = types.listtype + params = multidict.MultiDict( + foo='0,1, 2, three', + ) + self.assertEqual( + ['0', '1', '2', 'three'], + args.usertype_from_params(datatype, params, 'foo', hit_paths) + ) + self.assertEqual( + ['0', '1', '2', 'three'], + args.from_params(datatype, params, 'foo', hit_paths) + ) + self.assertEqual( + atypes.Unset, + args.usertype_from_params(datatype, params, 'bar', hit_paths) + ) + + def test_args_from_args(self): + + fromargs = ['one', 2, [0, '1', '2_34']] + fromkwargs = {'foo': '1, 2, 3'} + + @functions.signature(str, str, int, atypes.ArrayType(int), + types.listtype) + def myfunc(self, first, second, third, foo): + pass + funcdef = functions.FunctionDefinition.get(myfunc) + + newargs, newkwargs = args.args_from_args(funcdef, fromargs, fromkwargs) + self.assertEqual(['one', 2, [0, 1, 234]], newargs) + self.assertEqual({'foo': ['1', '2', '3']}, newkwargs) + + def test_args_from_params(self): + + @functions.signature(str, str, int, atypes.ArrayType(int), + types.listtype) + def myfunc(self, first, second, third, foo): + pass + funcdef = functions.FunctionDefinition.get(myfunc) + params = multidict.MultiDict( + foo='0,1, 2, three', + third='1', + second='2' + ) + self.assertEqual( + ([], {'foo': ['0', '1', '2', 'three'], 'second': 2, 'third': [1]}), + args.args_from_params(funcdef, params) + ) + + # unexpected param + params = multidict.MultiDict(bar='baz') + self.assertRaises(exception.UnknownArgument, args.args_from_params, + funcdef, params) + + # no params plus a body + params = multidict.MultiDict(__body__='') + self.assertEqual( + ([], {}), + args.args_from_params(funcdef, params) + ) + + def test_args_from_body(self): + @functions.signature(str, body=NestedObj) + def myfunc(self, nested): + pass + funcdef = functions.FunctionDefinition.get(myfunc) + mimetype = 'application/json' + body = b'{"o": {"id": 1234, "name": "an object"}}' + newargs, newkwargs = args.args_from_body(funcdef, body, mimetype) + + self.assertEqual(1234, newkwargs['nested'].o.id) + self.assertEqual('an object', newkwargs['nested'].o.name) + + self.assertEqual( + ((), {}), + args.args_from_body(funcdef, None, mimetype) + ) + + self.assertRaises(exception.ClientSideError, args.args_from_body, + funcdef, body, 'application/x-corba') + + self.assertEqual( + ((), {}), + args.args_from_body(funcdef, body, + 'application/x-www-form-urlencoded') + ) + + def test_combine_args(self): + + @functions.signature(str, str, int) + def myfunc(self, first, second,): + pass + funcdef = functions.FunctionDefinition.get(myfunc) + + # empty + self.assertEqual( + ([], {}), + args.combine_args( + funcdef, ( + ([], {}), + ([], {}), + ) + ) + ) + + # combine kwargs + self.assertEqual( + ([], {'first': 'one', 'second': 'two'}), + args.combine_args( + funcdef, ( + ([], {}), + ([], {'first': 'one', 'second': 'two'}), + ) + ) + ) + + # combine mixed args + self.assertEqual( + ([], {'first': 'one', 'second': 'two'}), + args.combine_args( + funcdef, ( + (['one'], {}), + ([], {'second': 'two'}), + ) + ) + ) + + # override kwargs + self.assertEqual( + ([], {'first': 'two'}), + args.combine_args( + funcdef, ( + ([], {'first': 'one'}), + ([], {'first': 'two'}), + ), + allow_override=True + ) + ) + + # override args + self.assertEqual( + ([], {'first': 'two', 'second': 'three'}), + args.combine_args( + funcdef, ( + (['one', 'three'], {}), + (['two'], {}), + ), + allow_override=True + ) + ) + + # can't override args + self.assertRaises(exception.ClientSideError, args.combine_args, + funcdef, + ((['one'], {}), (['two'], {}))) + + # can't override kwargs + self.assertRaises(exception.ClientSideError, args.combine_args, + funcdef, + (([], {'first': 'one'}), ([], {'first': 'two'}))) + + def test_get_args(self): + @functions.signature(str, str, int, atypes.ArrayType(int), + types.listtype, body=NestedObj) + def myfunc(self, first, second, third, foo, nested): + pass + funcdef = functions.FunctionDefinition.get(myfunc) + params = multidict.MultiDict( + foo='0,1, 2, three', + second='2' + ) + mimetype = 'application/json' + body = b'{"o": {"id": 1234, "name": "an object"}}' + fromargs = ['one'] + fromkwargs = {'third': '1'} + + newargs, newkwargs = args.get_args(funcdef, fromargs, fromkwargs, + params, body, mimetype) + self.assertEqual([], newargs) + n = newkwargs.pop('nested') + self.assertEqual({ + 'first': 'one', + 'foo': ['0', '1', '2', 'three'], + 'second': 2, + 'third': [1]}, + newkwargs + ) + self.assertEqual(1234, n.o.id) + self.assertEqual('an object', n.o.name) + + # check_arguments missing mandatory argument 'second' + params = multidict.MultiDict( + foo='0,1, 2, three', + ) + self.assertRaises(exception.MissingArgument, args.get_args, + funcdef, fromargs, fromkwargs, + params, body, mimetype)