Complex types should check unexpected attributes

Check if a request is passing more attributes for complex objects than
those defined in the API. WSME did not care if some unknown attribute
was passed to an API with a complex type, only checking the required
attributes. This is fixed raising a ValueError if more attributes are
given than expected, resulting in an HTTP response with a 400 status.
This helps check the validity of requests, which would otherwise
unexpectedly work with (partially) invalid data.

Closes-Bug: #1277571
Change-Id: Idf720a1c3fac8bdc8dca21a1ccdb126110dae62e
This commit is contained in:
Stéphane Bisinger 2015-05-13 19:08:56 +02:00
parent f28ec1354e
commit 6461c1b8a1
3 changed files with 129 additions and 17 deletions

View File

@ -62,3 +62,31 @@ class UnknownFunction(ClientSideError):
@property @property
def faultstring(self): def faultstring(self):
return _(six.u("Unknown function name: %s")) % (self.name) return _(six.u("Unknown function name: %s")) % (self.name)
class UnknownAttribute(ClientSideError):
def __init__(self, fieldname, attributes, msg=''):
self.fieldname = fieldname
self.attributes = attributes
self.msg = msg
super(UnknownAttribute, self).__init__(self.msg)
@property
def faultstring(self):
error = _("Unknown attribute for argument %(argn)s: %(attrs)s")
if len(self.attributes) > 1:
error = _("Unknown attributes for argument %(argn)s: %(attrs)s")
str_attrs = ", ".join(self.attributes)
return error % {'argn': self.fieldname, 'attrs': str_attrs}
def add_fieldname(self, name):
"""Add a fieldname to concatenate the full name.
Add a fieldname so that the whole hierarchy is displayed. Successive
calls to this method will prepend ``name`` to the hierarchy of names.
"""
if self.fieldname is not None:
self.fieldname = "{}.{}".format(name, self.fieldname)
else:
self.fieldname = name
super(UnknownAttribute, self).__init__(self.msg)

View File

@ -1,6 +1,4 @@
""" """REST+Json protocol implementation."""
REST+Json protocol implementation.
"""
from __future__ import absolute_import from __future__ import absolute_import
import datetime import datetime
import decimal import decimal
@ -9,10 +7,10 @@ import six
from simplegeneric import generic from simplegeneric import generic
from wsme.types import Unset import wsme.exc
import wsme.types import wsme.types
from wsme.types import Unset
import wsme.utils import wsme.utils
from wsme.exc import ClientSideError, UnknownArgument, InvalidInput
try: try:
@ -116,8 +114,7 @@ def datetime_tojson(datatype, value):
@generic @generic
def fromjson(datatype, value): def fromjson(datatype, value):
""" """A generic converter from json base types to python datatype.
A generic converter from json base types to python datatype.
If a non-complex user specific type is to be used in the api, If a non-complex user specific type is to be used in the api,
a specific fromjson should be added:: a specific fromjson should be added::
@ -135,16 +132,31 @@ def fromjson(datatype, value):
return None return None
if wsme.types.iscomplex(datatype): if wsme.types.iscomplex(datatype):
obj = datatype() obj = datatype()
for attrdef in wsme.types.list_attributes(datatype): attributes = wsme.types.list_attributes(datatype)
# Here we check that all the attributes in the value are also defined
# in our type definition, otherwise we raise an Error.
v_keys = set(value.keys())
a_keys = set(adef.name for adef in attributes)
if not v_keys <= a_keys:
raise wsme.exc.UnknownAttribute(None, v_keys - a_keys)
for attrdef in attributes:
if attrdef.name in value: if attrdef.name in value:
val_fromjson = fromjson(attrdef.datatype, value[attrdef.name]) try:
val_fromjson = fromjson(attrdef.datatype,
value[attrdef.name])
except wsme.exc.UnknownAttribute as e:
e.add_fieldname(attrdef.name)
raise
if getattr(attrdef, 'readonly', False): if getattr(attrdef, 'readonly', False):
raise InvalidInput(attrdef.name, val_fromjson, raise wsme.exc.InvalidInput(attrdef.name, val_fromjson,
"Cannot set read only field.") "Cannot set read only field.")
setattr(obj, attrdef.key, val_fromjson) setattr(obj, attrdef.key, val_fromjson)
elif attrdef.mandatory: elif attrdef.mandatory:
raise InvalidInput(attrdef.name, None, raise wsme.exc.InvalidInput(attrdef.name, None,
"Mandatory field missing.") "Mandatory field missing.")
return wsme.types.validate_value(datatype, obj) return wsme.types.validate_value(datatype, obj)
elif wsme.types.isusertype(datatype): elif wsme.types.isusertype(datatype):
value = datatype.frombasetype( value = datatype.frombasetype(
@ -243,13 +255,18 @@ def parse(s, datatypes, bodyarg, encoding='utf8'):
try: try:
jdata = jload(s) jdata = jload(s)
except ValueError: except ValueError:
raise ClientSideError("Request is not in valid JSON format") raise wsme.exc.ClientSideError("Request is not in valid JSON format")
if bodyarg: if bodyarg:
argname = list(datatypes.keys())[0] argname = list(datatypes.keys())[0]
try: try:
kw = {argname: fromjson(datatypes[argname], jdata)} kw = {argname: fromjson(datatypes[argname], jdata)}
except ValueError as e: except ValueError as e:
raise InvalidInput(argname, jdata, e.args[0]) raise wsme.exc.InvalidInput(argname, jdata, e.args[0])
except wsme.exc.UnknownAttribute as e:
# We only know the fieldname at this level, not in the
# called function. We fill in this information here.
e.add_fieldname(argname)
raise
else: else:
kw = {} kw = {}
extra_args = [] extra_args = []
@ -260,9 +277,14 @@ def parse(s, datatypes, bodyarg, encoding='utf8'):
try: try:
kw[key] = fromjson(datatypes[key], jdata[key]) kw[key] = fromjson(datatypes[key], jdata[key])
except ValueError as e: except ValueError as e:
raise InvalidInput(key, jdata[key], e.args[0]) raise wsme.exc.InvalidInput(key, jdata[key], e.args[0])
except wsme.exc.UnknownAttribute as e:
# We only know the fieldname at this level, not in the
# called function. We fill in this information here.
e.add_fieldname(key)
raise
if extra_args: if extra_args:
raise UnknownArgument(', '.join(extra_args)) raise wsme.exc.UnknownArgument(', '.join(extra_args))
return kw return kw

View File

@ -100,6 +100,10 @@ class Obj(wsme.types.Base):
name = wsme.types.text name = wsme.types.text
class NestedObj(wsme.types.Base):
o = Obj
class CRUDResult(object): class CRUDResult(object):
data = Obj data = Obj
message = wsme.types.text message = wsme.types.text
@ -476,6 +480,37 @@ class TestRestJson(wsme.tests.protocol.RestOnlyProtocolTestCase):
"invalid literal for int() with base 10: '%s'" % value "invalid literal for int() with base 10: '%s'" % value
) )
def test_parse_unexpected_attribute(self):
o = {
"id": "1",
"name": "test",
"other": "unknown",
"other2": "still unknown",
}
for ba in True, False:
jd = o if ba else {"o": o}
try:
parse(json.dumps(jd), {'o': Obj}, ba)
raise AssertionError("Object should not parse correcty.")
except wsme.exc.UnknownAttribute as e:
self.assertEqual(e.attributes, set(['other', 'other2']))
def test_parse_unexpected_nested_attribute(self):
no = {
"o": {
"id": "1",
"name": "test",
"other": "unknown",
},
}
for ba in False, True:
jd = no if ba else {"no": no}
try:
parse(json.dumps(jd), {'no': NestedObj}, ba)
except wsme.exc.UnknownAttribute as e:
self.assertEqual(e.attributes, set(['other']))
self.assertEqual(e.fieldname, "no.o")
def test_nest_result(self): def test_nest_result(self):
self.root.protocols[0].nest_result = True self.root.protocols[0].nest_result = True
r = self.app.get('/returntypes/getint.json') r = self.app.get('/returntypes/getint.json')
@ -659,6 +694,33 @@ class TestRestJson(wsme.tests.protocol.RestOnlyProtocolTestCase):
assert result['data']['name'] == u("test") assert result['data']['name'] == u("test")
assert result['message'] == "read" assert result['message'] == "read"
def test_unexpected_extra_arg(self):
headers = {
'Content-Type': 'application/json',
}
data = {"id": 1, "name": "test"}
content = json.dumps({"data": data, "other": "unexpected"})
res = self.app.put(
'/crud',
content,
headers=headers,
expect_errors=True)
self.assertEqual(res.status_int, 400)
def test_unexpected_extra_attribute(self):
"""Expect a failure if we send an unexpected object attribute."""
headers = {
'Content-Type': 'application/json',
}
data = {"id": 1, "name": "test", "other": "unexpected"}
content = json.dumps({"data": data})
res = self.app.put(
'/crud',
content,
headers=headers,
expect_errors=True)
self.assertEqual(res.status_int, 400)
def test_body_arg(self): def test_body_arg(self):
headers = { headers = {
'Content-Type': 'application/json', 'Content-Type': 'application/json',