Add json and param parsing to args

Some unused HTTP param to arg parsing has not been implemented to
reduce code complexity. This includes the following types:
- DictType
- complex types

Asserts are added to confirm these param types are not used in ironic
currently, and to prevent them being used in future development.

Story: 1651346
Task: 10551

Change-Id: Idfcf99216f10e8928fe4ba6202a7d69bfa916459
This commit is contained in:
Steve Baker 2020-01-28 14:07:11 +13:00
parent 0e65f0134d
commit 8006c9dfd2
6 changed files with 985 additions and 3 deletions

387
ironic/api/args.py Normal file
View File

@ -0,0 +1,387 @@
# Copyright 2011-2019 the WSME authors and contributors
# (See https://opendev.org/x/wsme/)
#
# This module is part of WSME and is also released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import cgi
import datetime
import decimal
import json
import logging
from dateutil import parser as dateparser
from ironic.api import types as atypes
from ironic.common import exception
LOG = logging.getLogger(__name__)
CONTENT_TYPE = 'application/json'
ACCEPT_CONTENT_TYPES = [
CONTENT_TYPE,
'text/javascript',
'application/javascript'
]
ENUM_TRUE = ('true', 't', 'yes', 'y', 'on', '1')
ENUM_FALSE = ('false', 'f', 'no', 'n', 'off', '0')
def fromjson_array(datatype, value):
if not isinstance(value, list):
raise ValueError("Value not a valid list: %s" % value)
return [fromjson(datatype.item_type, item) for item in value]
def fromjson_dict(datatype, value):
if not isinstance(value, dict):
raise ValueError("Value not a valid dict: %s" % value)
return dict((
(fromjson(datatype.key_type, item[0]),
fromjson(datatype.value_type, item[1]))
for item in value.items()))
def fromjson_bool(value):
if isinstance(value, (int, bool)):
return bool(value)
if value in ENUM_TRUE:
return True
if value in ENUM_FALSE:
return False
raise ValueError("Value not an unambiguous boolean: %s" % value)
def fromjson(datatype, value):
"""A generic converter from json base types to python datatype.
"""
if value is None:
return None
if isinstance(datatype, atypes.ArrayType):
return fromjson_array(datatype, value)
if isinstance(datatype, atypes.DictType):
return fromjson_dict(datatype, value)
if datatype is bytes:
if isinstance(value, (str, int, float)):
return str(value).encode('utf8')
return value
if datatype is str:
if isinstance(value, bytes):
return value.decode('utf-8')
return value
if datatype in (int, float):
return datatype(value)
if datatype is bool:
return fromjson_bool(value)
if datatype is decimal.Decimal:
return decimal.Decimal(value)
if datatype is datetime.datetime:
return dateparser.parse(value)
if atypes.iscomplex(datatype):
return fromjson_complex(datatype, value)
if atypes.isusertype(datatype):
return datatype.frombasetype(fromjson(datatype.basetype, value))
return value
def fromjson_complex(datatype, value):
obj = datatype()
attributes = atypes.list_attributes(datatype)
# Here we check that all the attributes in the value are also defined
# in our type definition, otherwise we raise an Error.
v_keys = set(value.keys())
a_keys = set(adef.name for adef in attributes)
if not v_keys <= a_keys:
raise exception.UnknownAttribute(None, v_keys - a_keys)
for attrdef in attributes:
if attrdef.name in value:
try:
val_fromjson = fromjson(attrdef.datatype,
value[attrdef.name])
except exception.UnknownAttribute as e:
e.add_fieldname(attrdef.name)
raise
if getattr(attrdef, 'readonly', False):
raise exception.InvalidInput(attrdef.name, val_fromjson,
"Cannot set read only field.")
setattr(obj, attrdef.key, val_fromjson)
elif attrdef.mandatory:
raise exception.InvalidInput(attrdef.name, None,
"Mandatory field missing.")
return atypes.validate_value(datatype, obj)
def parse(s, datatypes, bodyarg, encoding='utf8'):
jload = json.load
if not hasattr(s, 'read'):
if isinstance(s, bytes):
s = s.decode(encoding)
jload = json.loads
try:
jdata = jload(s)
except ValueError:
raise exception.ClientSideError("Request is not in valid JSON format")
if bodyarg:
argname = list(datatypes.keys())[0]
try:
kw = {argname: fromjson(datatypes[argname], jdata)}
except ValueError as e:
raise exception.InvalidInput(argname, jdata, e.args[0])
except exception.UnknownAttribute as e:
# We only know the fieldname at this level, not in the
# called function. We fill in this information here.
e.add_fieldname(argname)
raise
else:
kw = {}
extra_args = []
if not isinstance(jdata, dict):
raise exception.ClientSideError("Request must be a JSON dict")
for key in jdata:
if key not in datatypes:
extra_args.append(key)
else:
try:
kw[key] = fromjson(datatypes[key], jdata[key])
except ValueError as e:
raise exception.InvalidInput(key, jdata[key], e.args[0])
except exception.UnknownAttribute as e:
# We only know the fieldname at this level, not in the
# called function. We fill in this information here.
e.add_fieldname(key)
raise
if extra_args:
raise exception.UnknownArgument(', '.join(extra_args))
return kw
def from_param(datatype, value):
if datatype is datetime.datetime:
return dateparser.parse(value) if value else None
if datatype is atypes.File:
if isinstance(value, cgi.FieldStorage):
return atypes.File(fieldstorage=value)
return atypes.File(content=value)
if isinstance(datatype, atypes.UserType):
return datatype.frombasetype(
from_param(datatype.basetype, value))
if isinstance(datatype, atypes.ArrayType):
if value is None:
return value
return [
from_param(datatype.item_type, item)
for item in value
]
return datatype(value) if value is not None else None
def from_params(datatype, params, path, hit_paths):
if isinstance(datatype, atypes.ArrayType):
return array_from_params(datatype, params, path, hit_paths)
if isinstance(datatype, atypes.UserType):
return usertype_from_params(datatype, params, path, hit_paths)
if path in params:
assert not isinstance(datatype, atypes.DictType), \
'DictType unsupported'
assert not atypes.iscomplex(datatype) or datatype is atypes.File, \
'complex type unsupported'
hit_paths.add(path)
return from_param(datatype, params[path])
return atypes.Unset
def array_from_params(datatype, params, path, hit_paths):
if hasattr(params, 'getall'):
# webob multidict
def getall(params, path):
return params.getall(path)
elif hasattr(params, 'getlist'):
# werkzeug multidict
def getall(params, path): # noqa
return params.getlist(path)
if path in params:
hit_paths.add(path)
return [
from_param(datatype.item_type, value)
for value in getall(params, path)]
return atypes.Unset
def usertype_from_params(datatype, params, path, hit_paths):
if path in params:
hit_paths.add(path)
value = from_param(datatype.basetype, params[path])
if value is not atypes.Unset:
return datatype.frombasetype(value)
return atypes.Unset
def args_from_args(funcdef, args, kwargs):
newargs = []
for argdef, arg in zip(funcdef.arguments[:len(args)], args):
try:
newargs.append(from_param(argdef.datatype, arg))
except Exception as e:
if isinstance(argdef.datatype, atypes.UserType):
datatype_name = argdef.datatype.name
elif isinstance(argdef.datatype, type):
datatype_name = argdef.datatype.__name__
else:
datatype_name = argdef.datatype.__class__.__name__
raise exception.InvalidInput(
argdef.name,
arg,
"unable to convert to %(datatype)s. Error: %(error)s" % {
'datatype': datatype_name, 'error': e})
newkwargs = {}
for argname, value in kwargs.items():
newkwargs[argname] = from_param(
funcdef.get_arg(argname).datatype, value
)
return newargs, newkwargs
def args_from_params(funcdef, params):
kw = {}
hit_paths = set()
for argdef in funcdef.arguments:
value = from_params(
argdef.datatype, params, argdef.name, hit_paths)
if value is not atypes.Unset:
kw[argdef.name] = value
paths = set(params.keys())
unknown_paths = paths - hit_paths
if '__body__' in unknown_paths:
unknown_paths.remove('__body__')
if not funcdef.ignore_extra_args and unknown_paths:
raise exception.UnknownArgument(', '.join(unknown_paths))
return [], kw
def args_from_body(funcdef, body, mimetype):
if funcdef.body_type is not None:
datatypes = {funcdef.arguments[-1].name: funcdef.body_type}
else:
datatypes = dict(((a.name, a.datatype) for a in funcdef.arguments))
if not body:
return (), {}
if mimetype == "application/x-www-form-urlencoded":
# the parameters should have been parsed in params
return (), {}
elif mimetype not in ACCEPT_CONTENT_TYPES:
raise exception.ClientSideError("Unknown mimetype: %s" % mimetype,
status_code=415)
try:
kw = parse(
body, datatypes, bodyarg=funcdef.body_type is not None
)
except exception.UnknownArgument:
if not funcdef.ignore_extra_args:
raise
kw = {}
return (), kw
def combine_args(funcdef, akw, allow_override=False):
newargs, newkwargs = [], {}
for args, kwargs in akw:
for i, arg in enumerate(args):
n = funcdef.arguments[i].name
if not allow_override and n in newkwargs:
raise exception.ClientSideError(
"Parameter %s was given several times" % n)
newkwargs[n] = arg
for name, value in kwargs.items():
n = str(name)
if not allow_override and n in newkwargs:
raise exception.ClientSideError(
"Parameter %s was given several times" % n)
newkwargs[n] = value
return newargs, newkwargs
def get_args(funcdef, args, kwargs, params, body, mimetype):
"""Combine arguments from multiple sources
Combine arguments from :
* the host framework args and kwargs
* the request params
* the request body
Note that the host framework args and kwargs can be overridden
by arguments from params of body
"""
# get the body from params if not given directly
if not body and '__body__' in params:
body = params['__body__']
# extract args from the host args and kwargs
from_args = args_from_args(funcdef, args, kwargs)
# extract args from the request parameters
from_params = args_from_params(funcdef, params)
# extract args from the request body
from_body = args_from_body(funcdef, body, mimetype)
# combine params and body arguments
from_params_and_body = combine_args(
funcdef,
(from_params, from_body)
)
args, kwargs = combine_args(
funcdef,
(from_args, from_params_and_body),
allow_override=True
)
check_arguments(funcdef, args, kwargs)
return args, kwargs
def check_arguments(funcdef, args, kw):
"""Check if some arguments are missing"""
assert len(args) == 0
for arg in funcdef.arguments:
if arg.mandatory and arg.name not in kw:
raise exception.MissingArgument(arg.name)

View File

@ -25,8 +25,8 @@ import traceback
from oslo_config import cfg
from oslo_log import log
import pecan
import wsme.rest.args
from ironic.api import args as api_args
from ironic.api import functions
from ironic.api import types as atypes
@ -70,8 +70,8 @@ def expose(*args, **kwargs):
return_type = funcdef.return_type
try:
args, kwargs = wsme.rest.args.get_args(
funcdef, args, kwargs, pecan.request.params, None,
args, kwargs = api_args.get_args(
funcdef, args, kwargs, pecan.request.params,
pecan.request.body, pecan.request.content_type
)
result = f(self, *args, **kwargs)

View File

@ -22,6 +22,7 @@ from wsme.types import Enum # noqa
from wsme.types import File # noqa
from wsme.types import IntegerType # noqa
from wsme.types import iscomplex # noqa
from wsme.types import isusertype # noqa
from wsme.types import list_attributes # noqa
from wsme.types import registry # noqa
from wsme.types import StringType # noqa
@ -29,6 +30,7 @@ from wsme.types import text # noqa
from wsme.types import Unset # noqa
from wsme.types import UnsetType # noqa
from wsme.types import UserType # noqa
from wsme.types import validate_value # noqa
from wsme.types import wsattr # noqa
from wsme.types import wsproperty # noqa

View File

@ -725,3 +725,75 @@ class NodeIsRetired(Invalid):
class NoFreeIPMITerminalPorts(TemporaryFailure):
_msg_fmt = _("Unable to allocate a free port on host %(host)s for IPMI "
"terminal, not enough free ports.")
class InvalidInput(ClientSideError):
def __init__(self, fieldname, value, msg=''):
self.fieldname = fieldname
self.value = value
super(InvalidInput, self).__init__(msg)
@property
def faultstring(self):
return _(
"Invalid input for field/attribute %(fieldname)s. "
"Value: '%(value)s'. %(msg)s"
) % {
'fieldname': self.fieldname,
'value': self.value,
'msg': self.msg
}
class UnknownArgument(ClientSideError):
def __init__(self, argname, msg=''):
self.argname = argname
super(UnknownArgument, self).__init__(msg)
@property
def faultstring(self):
return _('Unknown argument: "%(argname)s"%(msg)s') % {
'argname': self.argname,
'msg': self.msg and ": " + self.msg or ""
}
class MissingArgument(ClientSideError):
def __init__(self, argname, msg=''):
self.argname = argname
super(MissingArgument, self).__init__(msg)
@property
def faultstring(self):
return _('Missing argument: "%(argname)s"%(msg)s') % {
'argname': self.argname,
'msg': self.msg and ": " + self.msg or ""
}
class UnknownAttribute(ClientSideError):
def __init__(self, fieldname, attributes, msg=''):
self.fieldname = fieldname
self.attributes = attributes
self.msg = msg
super(UnknownAttribute, self).__init__(self.msg)
@property
def faultstring(self):
error = _("Unknown attribute for argument %(argn)s: %(attrs)s")
if len(self.attributes) > 1:
error = _("Unknown attributes for argument %(argn)s: %(attrs)s")
str_attrs = ", ".join(self.attributes)
return error % {'argn': self.fieldname, 'attrs': str_attrs}
def add_fieldname(self, name):
"""Add a fieldname to concatenate the full name.
Add a fieldname so that the whole hierarchy is displayed. Successive
calls to this method will prepend ``name`` to the hierarchy of names.
"""
if self.fieldname is not None:
self.fieldname = "{}.{}".format(name, self.fieldname)
else:
self.fieldname = name
super(UnknownAttribute, self).__init__(self.msg)

View File

@ -42,6 +42,15 @@ from ironic.conf import CONF
LOG = logging.getLogger(__name__)
DATE_RE = r'(?P<year>-?\d{4,})-(?P<month>\d{2})-(?P<day>\d{2})'
TIME_RE = r'(?P<hour>\d{2}):(?P<min>\d{2}):(?P<sec>\d{2})' + \
r'(\.(?P<sec_frac>\d+))?'
TZ_RE = r'((?P<tz_sign>[+-])(?P<tz_hour>\d{2}):(?P<tz_min>\d{2}))' + \
r'|(?P<tz_z>Z)'
DATETIME_RE = re.compile(
'%sT%s(%s)?' % (DATE_RE, TIME_RE, TZ_RE))
warn_deprecated_extra_vif_port_id = False

View File

@ -0,0 +1,512 @@
# Copyright 2020 Red Hat, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
import decimal
import io
from webob import multidict
from ironic.api import args
from ironic.api.controllers.v1 import types
from ironic.api import functions
from ironic.api import types as atypes
from ironic.common import exception
from ironic.tests import base as test_base
class Obj(atypes.Base):
id = atypes.wsattr(int, mandatory=True)
name = str
readonly_field = atypes.wsattr(str, readonly=True)
default_field = atypes.wsattr(str, default='foo')
unset_me = str
class NestedObj(atypes.Base):
o = Obj
class TestArgs(test_base.TestCase):
def test_fromjson_array(self):
atype = atypes.ArrayType(int)
self.assertEqual(
[0, 1, 1234, None],
args.fromjson_array(atype, [0, '1', '1_234', None])
)
self.assertRaises(ValueError, args.fromjson_array,
atype, ['one', 'two', 'three'])
self.assertRaises(ValueError, args.fromjson_array,
atype, 'one')
def test_fromjson_dict(self):
dtype = atypes.DictType(str, int)
self.assertEqual({
'zero': 0,
'one': 1,
'etc': 1234,
'none': None
}, args.fromjson_dict(dtype, {
'zero': 0,
'one': '1',
'etc': '1_234',
'none': None
}))
self.assertRaises(ValueError, args.fromjson_dict,
dtype, [])
self.assertRaises(ValueError, args.fromjson_dict,
dtype, {'one': 'one'})
def test_fromjson_bool(self):
for b in (1, 2, True, 'true', 't', 'yes', 'y', 'on', '1'):
self.assertTrue(args.fromjson_bool(b))
for b in (0, False, 'false', 'f', 'no', 'n', 'off', '0'):
self.assertFalse(args.fromjson_bool(b))
for b in ('yup', 'yeet', 'NOPE', 3.14):
self.assertRaises(ValueError, args.fromjson_bool, b)
def test_fromjson(self):
# parse None
self.assertIsNone(args.fromjson(None, None))
# parse array
atype = atypes.ArrayType(int)
self.assertEqual(
[0, 1, 1234, None],
args.fromjson(atype, [0, '1', '1_234', None])
)
# parse dict
dtype = atypes.DictType(str, int)
self.assertEqual({
'zero': 0,
'one': 1,
'etc': 1234,
'none': None
}, args.fromjson(dtype, {
'zero': 0,
'one': '1',
'etc': '1_234',
'none': None
}))
# parse bytes
self.assertEqual(
b'asdf',
args.fromjson(bytes, b'asdf')
)
self.assertEqual(
b'asdf',
args.fromjson(bytes, 'asdf')
)
self.assertEqual(
b'33',
args.fromjson(bytes, 33)
)
self.assertEqual(
b'3.14',
args.fromjson(bytes, 3.14)
)
# parse str
self.assertEqual(
'asdf',
args.fromjson(str, b'asdf')
)
self.assertEqual(
'asdf',
args.fromjson(str, 'asdf')
)
# parse int/float
self.assertEqual(
3,
args.fromjson(int, '3')
)
self.assertEqual(
3,
args.fromjson(int, 3)
)
self.assertEqual(
3.14,
args.fromjson(float, 3.14)
)
# parse bool
self.assertFalse(args.fromjson(bool, 'no'))
self.assertTrue(args.fromjson(bool, 'yes'))
# parse decimal
self.assertEqual(
decimal.Decimal(3.14),
args.fromjson(decimal.Decimal, 3.14)
)
# parse datetime
expected = datetime.datetime(2015, 8, 13, 11, 38, 9, 496475)
self.assertEqual(
expected,
args.fromjson(datetime.datetime, '2015-08-13T11:38:09.496475')
)
# parse complex
n = args.fromjson(NestedObj, {'o': {
'id': 1234,
'name': 'an object'
}})
self.assertIsInstance(n.o, Obj)
self.assertEqual(1234, n.o.id)
self.assertEqual('an object', n.o.name)
self.assertEqual('foo', n.o.default_field)
# parse usertype
self.assertEqual(
['0', '1', '2', 'three'],
args.fromjson(types.listtype, '0,1, 2, three')
)
def test_fromjson_complex(self):
n = args.fromjson_complex(NestedObj, {'o': {
'id': 1234,
'name': 'an object'
}})
self.assertIsInstance(n.o, Obj)
self.assertEqual(1234, n.o.id)
self.assertEqual('an object', n.o.name)
self.assertEqual('foo', n.o.default_field)
e = self.assertRaises(exception.UnknownAttribute,
args.fromjson_complex,
Obj, {'ooo': {}})
self.assertEqual({'ooo'}, e.attributes)
e = self.assertRaises(exception.InvalidInput, args.fromjson_complex,
Obj,
{'name': 'an object'})
self.assertEqual('id', e.fieldname)
self.assertEqual('Mandatory field missing.', e.msg)
e = self.assertRaises(exception.InvalidInput, args.fromjson_complex,
Obj,
{'id': 1234, 'readonly_field': 'foo'})
self.assertEqual('readonly_field', e.fieldname)
self.assertEqual('Cannot set read only field.', e.msg)
def test_parse(self):
# source as bytes
s = b'{"o": {"id": 1234, "name": "an object"}}'
# test bodyarg=True
n = args.parse(s, {"o": NestedObj}, True)['o']
self.assertEqual(1234, n.o.id)
self.assertEqual('an object', n.o.name)
# source as file
s = io.StringIO('{"o": {"id": 1234, "name": "an object"}}')
# test bodyarg=False
n = args.parse(s, {"o": Obj}, False)['o']
self.assertEqual(1234, n.id)
self.assertEqual('an object', n.name)
# fromjson ValueError
s = '{"o": ["id", "name"]}'
self.assertRaises(exception.InvalidInput, args.parse,
s, {"o": atypes.DictType(str, str)}, False)
s = '["id", "name"]'
self.assertRaises(exception.InvalidInput, args.parse,
s, {"o": atypes.DictType(str, str)}, True)
# fromjson UnknownAttribute
s = '{"o": {"foo": "bar", "id": 1234, "name": "an object"}}'
self.assertRaises(exception.UnknownAttribute, args.parse,
s, {"o": NestedObj}, True)
self.assertRaises(exception.UnknownAttribute, args.parse,
s, {"o": Obj}, False)
# invalid json
s = '{Sunn O)))}'
self.assertRaises(exception.ClientSideError, args.parse,
s, {"o": Obj}, False)
# extra args
s = '{"foo": "bar", "o": {"id": 1234, "name": "an object"}}'
self.assertRaises(exception.UnknownArgument, args.parse,
s, {"o": Obj}, False)
def test_from_param(self):
# datetime param
expected = datetime.datetime(2015, 8, 13, 11, 38, 9, 496475)
self.assertEqual(
expected,
args.from_param(datetime.datetime, '2015-08-13T11:38:09.496475')
)
self.assertIsNone(args.from_param(datetime.datetime, None))
# file param
self.assertEqual(
b'foo',
args.from_param(atypes.File, b'foo').content
)
# usertype param
self.assertEqual(
['0', '1', '2', 'three'],
args.from_param(types.listtype, '0,1, 2, three')
)
# array param
atype = atypes.ArrayType(int)
self.assertEqual(
[0, 1, 1234, None],
args.from_param(atype, [0, '1', '1_234', None])
)
self.assertIsNone(args.from_param(atype, None))
# string param
self.assertEqual('foo', args.from_param(str, 'foo'))
self.assertIsNone(args.from_param(str, None))
# string param with from_params
hit_paths = set()
params = multidict.MultiDict(
foo='bar',
)
self.assertEqual(
'bar',
args.from_params(str, params, 'foo', hit_paths)
)
self.assertEqual({'foo'}, hit_paths)
def test_array_from_params(self):
hit_paths = set()
datatype = atypes.ArrayType(str)
params = multidict.MultiDict(
foo='bar',
one='two'
)
self.assertEqual(
['bar'],
args.from_params(datatype, params, 'foo', hit_paths)
)
self.assertEqual({'foo'}, hit_paths)
self.assertEqual(
['two'],
args.array_from_params(datatype, params, 'one', hit_paths)
)
self.assertEqual({'foo', 'one'}, hit_paths)
def test_usertype_from_params(self):
hit_paths = set()
datatype = types.listtype
params = multidict.MultiDict(
foo='0,1, 2, three',
)
self.assertEqual(
['0', '1', '2', 'three'],
args.usertype_from_params(datatype, params, 'foo', hit_paths)
)
self.assertEqual(
['0', '1', '2', 'three'],
args.from_params(datatype, params, 'foo', hit_paths)
)
self.assertEqual(
atypes.Unset,
args.usertype_from_params(datatype, params, 'bar', hit_paths)
)
def test_args_from_args(self):
fromargs = ['one', 2, [0, '1', '2_34']]
fromkwargs = {'foo': '1, 2, 3'}
@functions.signature(str, str, int, atypes.ArrayType(int),
types.listtype)
def myfunc(self, first, second, third, foo):
pass
funcdef = functions.FunctionDefinition.get(myfunc)
newargs, newkwargs = args.args_from_args(funcdef, fromargs, fromkwargs)
self.assertEqual(['one', 2, [0, 1, 234]], newargs)
self.assertEqual({'foo': ['1', '2', '3']}, newkwargs)
def test_args_from_params(self):
@functions.signature(str, str, int, atypes.ArrayType(int),
types.listtype)
def myfunc(self, first, second, third, foo):
pass
funcdef = functions.FunctionDefinition.get(myfunc)
params = multidict.MultiDict(
foo='0,1, 2, three',
third='1',
second='2'
)
self.assertEqual(
([], {'foo': ['0', '1', '2', 'three'], 'second': 2, 'third': [1]}),
args.args_from_params(funcdef, params)
)
# unexpected param
params = multidict.MultiDict(bar='baz')
self.assertRaises(exception.UnknownArgument, args.args_from_params,
funcdef, params)
# no params plus a body
params = multidict.MultiDict(__body__='')
self.assertEqual(
([], {}),
args.args_from_params(funcdef, params)
)
def test_args_from_body(self):
@functions.signature(str, body=NestedObj)
def myfunc(self, nested):
pass
funcdef = functions.FunctionDefinition.get(myfunc)
mimetype = 'application/json'
body = b'{"o": {"id": 1234, "name": "an object"}}'
newargs, newkwargs = args.args_from_body(funcdef, body, mimetype)
self.assertEqual(1234, newkwargs['nested'].o.id)
self.assertEqual('an object', newkwargs['nested'].o.name)
self.assertEqual(
((), {}),
args.args_from_body(funcdef, None, mimetype)
)
self.assertRaises(exception.ClientSideError, args.args_from_body,
funcdef, body, 'application/x-corba')
self.assertEqual(
((), {}),
args.args_from_body(funcdef, body,
'application/x-www-form-urlencoded')
)
def test_combine_args(self):
@functions.signature(str, str, int)
def myfunc(self, first, second,):
pass
funcdef = functions.FunctionDefinition.get(myfunc)
# empty
self.assertEqual(
([], {}),
args.combine_args(
funcdef, (
([], {}),
([], {}),
)
)
)
# combine kwargs
self.assertEqual(
([], {'first': 'one', 'second': 'two'}),
args.combine_args(
funcdef, (
([], {}),
([], {'first': 'one', 'second': 'two'}),
)
)
)
# combine mixed args
self.assertEqual(
([], {'first': 'one', 'second': 'two'}),
args.combine_args(
funcdef, (
(['one'], {}),
([], {'second': 'two'}),
)
)
)
# override kwargs
self.assertEqual(
([], {'first': 'two'}),
args.combine_args(
funcdef, (
([], {'first': 'one'}),
([], {'first': 'two'}),
),
allow_override=True
)
)
# override args
self.assertEqual(
([], {'first': 'two', 'second': 'three'}),
args.combine_args(
funcdef, (
(['one', 'three'], {}),
(['two'], {}),
),
allow_override=True
)
)
# can't override args
self.assertRaises(exception.ClientSideError, args.combine_args,
funcdef,
((['one'], {}), (['two'], {})))
# can't override kwargs
self.assertRaises(exception.ClientSideError, args.combine_args,
funcdef,
(([], {'first': 'one'}), ([], {'first': 'two'})))
def test_get_args(self):
@functions.signature(str, str, int, atypes.ArrayType(int),
types.listtype, body=NestedObj)
def myfunc(self, first, second, third, foo, nested):
pass
funcdef = functions.FunctionDefinition.get(myfunc)
params = multidict.MultiDict(
foo='0,1, 2, three',
second='2'
)
mimetype = 'application/json'
body = b'{"o": {"id": 1234, "name": "an object"}}'
fromargs = ['one']
fromkwargs = {'third': '1'}
newargs, newkwargs = args.get_args(funcdef, fromargs, fromkwargs,
params, body, mimetype)
self.assertEqual([], newargs)
n = newkwargs.pop('nested')
self.assertEqual({
'first': 'one',
'foo': ['0', '1', '2', 'three'],
'second': 2,
'third': [1]},
newkwargs
)
self.assertEqual(1234, n.o.id)
self.assertEqual('an object', n.o.name)
# check_arguments missing mandatory argument 'second'
params = multidict.MultiDict(
foo='0,1, 2, three',
)
self.assertRaises(exception.MissingArgument, args.get_args,
funcdef, fromargs, fromkwargs,
params, body, mimetype)