Multiple protocol accept or content-type matching

The changes in 8710dabb65 broke
protocol failover when the REST protocol is listed before others
(see bug referenced below).

This patch tries to solve both issues by trying to match accept over all
the protocols, only giving up a 406 or 415 if all protocols fail, using
the last failure as the error message.

Related-Bug: #1419110
Closes-Bug: #1442710
Change-Id: I328a392151013c46207519c245213d5dec48ecc9
This commit is contained in:
Chris Dent 2015-04-10 18:41:32 +01:00 committed by Stéphane Bisinger
parent f66cf4c3cc
commit e31045e57a
4 changed files with 100 additions and 13 deletions

View File

@ -117,7 +117,7 @@ def getprotocol(name, **options):
def media_type_accept(request, content_types):
"""Return True if the requested media type is available.
"""Validate media types against request.method.
When request.method is GET or HEAD compare with the Accept header.
When request.method is POST, PUT or PATCH compare with the Content-Type
@ -131,7 +131,6 @@ def media_type_accept(request, content_types):
error_message = ('Unacceptable Accept type: %s not in %s'
% (request.accept, content_types))
raise ClientSideError(error_message, status_code=406)
return False
elif request.method in ['PUT', 'POST', 'PATCH']:
content_type = request.headers.get('Content-Type')
if content_type:

View File

@ -149,6 +149,7 @@ class WSRoot(object):
request.body[:512] or
request.body) or '')
protocol = None
error = ClientSideError(status_code=406)
path = str(request.path)
assert path.startswith(self._webpath)
path = path[len(self._webpath) + 1:]
@ -157,9 +158,16 @@ class WSRoot(object):
else:
for p in self.protocols:
if p.accept(request):
protocol = p
break
try:
if p.accept(request):
protocol = p
break
except ClientSideError as e:
error = e
# If we could not select a protocol, we raise the last exception
# that we got, or the default one.
if not protocol:
raise error
return protocol
def _do_call(self, protocol, context):
@ -232,11 +240,6 @@ class WSRoot(object):
msg = None
error_status = 500
protocol = self._select_protocol(request)
if protocol is None:
if request.method in ['GET', 'HEAD']:
error_status = 406
elif request.method in ['POST', 'PUT', 'PATCH']:
error_status = 415
except ClientSideError as e:
error_status = e.code
msg = e.faultstring
@ -248,7 +251,7 @@ class WSRoot(object):
error_status = 500
if protocol is None:
if msg is None:
if not msg:
msg = ("None of the following protocols can handle this "
"request : %s" % ','.join((
p.name for p in self.protocols)))
@ -296,6 +299,10 @@ class WSRoot(object):
else:
res.status = protocol.get_response_status(request)
res_content_type = protocol.get_response_contenttype(request)
except ClientSideError as e:
request.server_errorcount += 1
res.status = e.code
res.text = e.faultstring
except Exception:
infos = wsme.api.format_exception(sys.exc_info(), self._debug)
request.server_errorcount += 1

View File

@ -201,6 +201,7 @@ Value should be one of:"))
app = webtest.TestApp(r.wsgiapp())
res = app.get('/', expect_errors=True)
print(res.status_int)
assert res.status_int == 406
print(res.body)
assert res.body.find(

View File

@ -3,9 +3,12 @@
import unittest
from wsme import WSRoot
import wsme.protocol
import wsme.rest.protocol
from wsme.root import default_prepare_response_body
from six import b, u
from webob import Request
class TestRoot(unittest.TestCase):
@ -24,9 +27,9 @@ class TestRoot(unittest.TestCase):
default_prepare_response_body(None, [u('a'), u('b')]) == u('a\nb')
def test_protocol_selection_error(self):
import wsme.protocol
class P(wsme.protocol.Protocol):
name = "test"
def accept(self, r):
raise Exception('test')
@ -40,3 +43,80 @@ class TestRoot(unittest.TestCase):
assert res.content_type == 'text/plain'
assert (res.text ==
'Unexpected error while selecting protocol: test'), req.text
def test_protocol_selection_accept_mismatch(self):
"""Verify that we get a 406 error on wrong Accept header."""
class P(wsme.protocol.Protocol):
name = "test"
def accept(self, r):
return False
root = WSRoot()
root.addprotocol(wsme.rest.protocol.RestProtocol())
root.addprotocol(P())
req = Request.blank('/test?check=a&check=b&name=Bob')
req.method = 'GET'
res = root._handle_request(req)
assert res.status_int == 406
assert res.content_type == 'text/plain'
assert res.text.startswith(
'None of the following protocols can handle this request'
), req.text
def test_protocol_selection_content_type_mismatch(self):
"""Verify that we get a 415 error on wrong Content-Type header."""
class P(wsme.protocol.Protocol):
name = "test"
def accept(self, r):
return False
root = WSRoot()
root.addprotocol(wsme.rest.protocol.RestProtocol())
root.addprotocol(P())
req = Request.blank('/test?check=a&check=b&name=Bob')
req.method = 'POST'
req.headers['Content-Type'] = "test/unsupported"
res = root._handle_request(req)
assert res.status_int == 415
assert res.content_type == 'text/plain'
assert res.text.startswith(
'Unacceptable Content-Type: test/unsupported not in'
), req.text
def test_protocol_selection_get_method(self):
class P(wsme.protocol.Protocol):
name = "test"
def accept(self, r):
return True
root = WSRoot()
root.addprotocol(wsme.rest.protocol.RestProtocol())
root.addprotocol(P())
req = Request.blank('/test?check=a&check=b&name=Bob')
req.method = 'GET'
req.headers['Accept'] = 'test/fake'
p = root._select_protocol(req)
assert p.name == "test"
def test_protocol_selection_post_method(self):
class P(wsme.protocol.Protocol):
name = "test"
def accept(self, r):
return True
root = WSRoot()
root.addprotocol(wsme.rest.protocol.RestProtocol())
root.addprotocol(P())
req = Request.blank('/test?check=a&check=b&name=Bob')
req.headers['Content-Type'] = 'test/fake'
req.method = 'POST'
p = root._select_protocol(req)
assert p.name == "test"