Modification to add additional information in the HTTPCheck request.

make it easier to reuse the invocation logic for check objects

Provide a new private function in oslo_policy._checks to evaluate a
check object. This function protects against API changes to the check
classes by inspecting the set of arguments accepted.

Update Enforcer to use the new function instead of invoking checks
directly.

Update the nested check classes (and, or, not) to use the new function
instead of invoking their sub-rules directly.

Update the way mocks were being used in some tests to replace them
with real minimal classes that implement the necessary APIs.

Simplify a few tests that were confirming multiple behaviors (for
example, the result of a compound check as well as the arguments
passed to its nested rules).

Ensure that we have test cases for invoking nested rules that do and
do not accept the new current_rule argument.

Change-Id: Ib9edd7954d0b977950be536fa9434243b0de7fcf
Signed-off-by: Doug Hellmann <doug@doughellmann.com>
This commit is contained in:
Thomas Duval 2017-08-28 15:40:02 +02:00
parent 7b1a6c16bd
commit 70ba1beb3e
3 changed files with 269 additions and 47 deletions

View File

@ -19,6 +19,7 @@ import abc
import ast
import contextlib
import copy
import inspect
from oslo_serialization import jsonutils
import requests
@ -28,6 +29,49 @@ import six
registered_checks = {}
def _check(rule, target, creds, enforcer, current_rule):
"""Evaluate the rule.
This private method is meant to be used by the enforcer to call
the rule. It can also be used by built-in checks that have nested
rules.
We use a private function because it makes it easier to change the
API without having an impact on subclasses not defined within the
oslo.policy library.
We don't put this logic in Enforcer.enforce() and invoke that
method recursively because that changes the BaseCheck API to
require that the enforcer argument to __call__() be a valid
Enforcer instance (as evidenced by all of the breaking unit
tests).
We don't put this in a private method of BaseCheck because that
propagates the problem of extending the list of arguments to
__call__() if subclasses change the implementation of the
function.
:param rule: A check object.
:type rule: BaseCheck
:param target: Attributes of the object of the operation.
:type target: dict
:param creds: Attributes of the user performing the operation.
:type creds: dict
:param enforcer: The Enforcer being used.
:type enforcer: Enforcer
:param current_rule: The name of the policy being checked.
:type current_rule: str
"""
# Evaluate the rule
argspec = inspect.getargspec(rule.__call__)
rule_args = [target, creds, enforcer]
# Check if the rule argument must be included or not
if len(argspec.args) > 4:
rule_args.append(current_rule)
return rule(*rule_args)
@six.add_metaclass(abc.ABCMeta)
class BaseCheck(object):
"""Abstract base class for Check classes."""
@ -39,7 +83,7 @@ class BaseCheck(object):
pass
@abc.abstractmethod
def __call__(self, target, cred, enforcer):
def __call__(self, target, cred, enforcer, current_rule=None):
"""Triggers if instance of the class is called.
Performs the check. Returns False to reject the access or a
@ -57,7 +101,7 @@ class FalseCheck(BaseCheck):
return '!'
def __call__(self, target, cred, enforcer):
def __call__(self, target, cred, enforcer, current_rule=None):
"""Check the policy."""
return False
@ -71,7 +115,7 @@ class TrueCheck(BaseCheck):
return '@'
def __call__(self, target, cred, enforcer):
def __call__(self, target, cred, enforcer, current_rule=None):
"""Check the policy."""
return True
@ -97,13 +141,13 @@ class NotCheck(BaseCheck):
return 'not %s' % self.rule
def __call__(self, target, cred, enforcer):
def __call__(self, target, cred, enforcer, current_rule=None):
"""Check the policy.
Returns the logical inverse of the wrapped check.
"""
return not self.rule(target, cred, enforcer)
return not _check(self.rule, target, cred, enforcer, current_rule)
class AndCheck(BaseCheck):
@ -115,14 +159,14 @@ class AndCheck(BaseCheck):
return '(%s)' % ' and '.join(str(r) for r in self.rules)
def __call__(self, target, cred, enforcer):
def __call__(self, target, cred, enforcer, current_rule=None):
"""Check the policy.
Requires that all rules accept in order to return True.
"""
for rule in self.rules:
if not rule(target, cred, enforcer):
if not _check(rule, target, cred, enforcer, current_rule):
return False
return True
@ -150,14 +194,14 @@ class OrCheck(BaseCheck):
return '(%s)' % ' or '.join(str(r) for r in self.rules)
def __call__(self, target, cred, enforcer):
def __call__(self, target, cred, enforcer, current_rule=None):
"""Check the policy.
Requires that at least one rule accept in order to return True.
"""
for rule in self.rules:
if rule(target, cred, enforcer):
if _check(rule, target, cred, enforcer, current_rule):
return True
return False
@ -199,9 +243,15 @@ def register(name, func=None):
@register('rule')
class RuleCheck(Check):
def __call__(self, target, creds, enforcer):
def __call__(self, target, creds, enforcer, current_rule=None):
try:
return enforcer.rules[self.match](target, creds, enforcer)
return _check(
rule=enforcer.rules[self.match],
target=target,
creds=creds,
enforcer=enforcer,
current_rule=current_rule,
)
except KeyError:
# We don't have any matching rule; fail closed
return False
@ -211,7 +261,7 @@ class RuleCheck(Check):
class RoleCheck(Check):
"""Check that there is a matching role in the ``creds`` dict."""
def __call__(self, target, creds, enforcer):
def __call__(self, target, creds, enforcer, current_rule=None):
try:
match = self.match % target
except KeyError:
@ -231,7 +281,7 @@ class HttpCheck(Check):
is exactly ``True``.
"""
def __call__(self, target, creds, enforcer):
def __call__(self, target, creds, enforcer, current_rule=None):
url = ('http:' + self.match) % target
# Convert instances of object() in target temporarily to
@ -242,7 +292,8 @@ class HttpCheck(Check):
element = target.get(key)
if type(element) is object:
temp_target[key] = {}
data = {'target': jsonutils.dumps(temp_target),
data = {'rule': jsonutils.dumps(current_rule),
'target': jsonutils.dumps(temp_target),
'credentials': jsonutils.dumps(creds)}
with contextlib.closing(requests.post(url, data=data)) as r:
return r.text == 'True'
@ -291,7 +342,7 @@ class GenericCheck(Check):
else:
return self._find_in_dict(test_value, path_segments, match)
def __call__(self, target, creds, enforcer):
def __call__(self, target, creds, enforcer, current_rule=None):
try:
match = self.match % target

View File

@ -729,18 +729,33 @@ class Enforcer(object):
# Allow the rule to be a Check tree
if isinstance(rule, _checks.BaseCheck):
result = rule(target, creds, self)
# If the thing we're given is a Check, we don't know the
# name of the rule, so pass None for current_rule.
result = _checks._check(
rule=rule,
target=target,
creds=creds,
enforcer=self,
current_rule=None,
)
elif not self.rules:
# No rules to reference means we're going to fail closed
result = False
else:
try:
# Evaluate the rule
result = self.rules[rule](target, creds, self)
to_check = self.rules[rule]
except KeyError:
LOG.debug('Rule [%s] does not exist', rule)
# If the rule doesn't exist, fail closed
result = False
else:
result = _checks._check(
rule=to_check,
target=target,
creds=creds,
enforcer=self,
current_rule=rule,
)
# If it is False, raise the exception if requested
if do_raise and not result:

View File

@ -51,21 +51,17 @@ class RuleCheckTestCase(base.PolicyBaseTestCase):
self.assertFalse(check('target', 'creds', self.enforcer))
def test_rule_false(self):
self.enforcer.rules = dict(spam=mock.Mock(return_value=False))
self.enforcer.rules = dict(spam=_BoolCheck(False))
check = _checks.RuleCheck('rule', 'spam')
self.assertFalse(check('target', 'creds', self.enforcer))
self.enforcer.rules['spam'].assert_called_once_with('target', 'creds',
self.enforcer)
def test_rule_true(self):
self.enforcer.rules = dict(spam=mock.Mock(return_value=True))
self.enforcer.rules = dict(spam=_BoolCheck(True))
check = _checks.RuleCheck('rule', 'spam')
self.assertTrue(check('target', 'creds', self.enforcer))
self.enforcer.rules['spam'].assert_called_once_with('target', 'creds',
self.enforcer)
class RoleCheckTestCase(base.PolicyBaseTestCase):
@ -122,7 +118,8 @@ class HttpCheckTestCase(base.PolicyBaseTestCase):
last_request = self.requests_mock.last_request
self.assertEqual('POST', last_request.method)
self.assertEqual(dict(target=target_dict, credentials=cred_dict),
self.assertEqual(dict(target=target_dict, credentials=cred_dict,
rule=None),
self.decode_post_data(last_request.body))
def test_reject(self):
@ -136,7 +133,8 @@ class HttpCheckTestCase(base.PolicyBaseTestCase):
last_request = self.requests_mock.last_request
self.assertEqual('POST', last_request.method)
self.assertEqual(dict(target=target_dict, credentials=cred_dict),
self.assertEqual(dict(target=target_dict, credentials=cred_dict,
rule=None),
self.decode_post_data(last_request.body))
def test_http_with_objects_in_target(self):
@ -161,6 +159,40 @@ class HttpCheckTestCase(base.PolicyBaseTestCase):
dict(user='user', roles=['a', 'b', 'c']),
self.enforcer))
def test_accept_with_rule_in_argument(self):
self.requests_mock.post('http://example.com/target', text='True')
check = _checks.HttpCheck('http', '//example.com/%(name)s')
target_dict = dict(name='target', spam='spammer')
cred_dict = dict(user='user', roles=['a', 'b', 'c'])
current_rule = "a_rule"
self.assertTrue(check(target_dict, cred_dict, self.enforcer,
current_rule))
last_request = self.requests_mock.last_request
self.assertEqual('POST', last_request.method)
self.assertEqual(dict(target=target_dict, credentials=cred_dict,
rule=current_rule),
self.decode_post_data(last_request.body))
def test_reject_with_rule_in_argument(self):
self.requests_mock.post("http://example.com/target", text='other')
check = _checks.HttpCheck('http', '//example.com/%(name)s')
target_dict = dict(name='target', spam='spammer')
cred_dict = dict(user='user', roles=['a', 'b', 'c'])
current_rule = "a_rule"
self.assertFalse(check(target_dict, cred_dict, self.enforcer,
current_rule))
last_request = self.requests_mock.last_request
self.assertEqual('POST', last_request.method)
self.assertEqual(dict(target=target_dict, credentials=cred_dict,
rule=current_rule),
self.decode_post_data(last_request.body))
class GenericCheckTestCase(base.PolicyBaseTestCase):
def test_no_cred(self):
@ -339,18 +371,60 @@ class NotCheckTestCase(test_base.BaseTestCase):
self.assertEqual('not rule', str(check))
def test_call_true(self):
rule = mock.Mock(return_value=True)
rule = _checks.TrueCheck()
check = _checks.NotCheck(rule)
self.assertFalse(check('target', 'cred', None))
rule.assert_called_once_with('target', 'cred', None)
def test_call_false(self):
rule = mock.Mock(return_value=False)
rule = _checks.FalseCheck()
check = _checks.NotCheck(rule)
self.assertTrue(check('target', 'cred', None))
rule.assert_called_once_with('target', 'cred', None)
def test_rule_takes_current_rule(self):
results = []
class TestCheck(object):
def __call__(self, target, cred, enforcer, current_rule=None):
results.append((target, cred, enforcer, current_rule))
return True
check = _checks.NotCheck(TestCheck())
self.assertFalse(check('target', 'cred', None, current_rule="a_rule"))
self.assertEqual(
[('target', 'cred', None, 'a_rule')],
results,
)
def test_rule_does_not_take_current_rule(self):
results = []
class TestCheck(object):
def __call__(self, target, cred, enforcer):
results.append((target, cred, enforcer))
return True
check = _checks.NotCheck(TestCheck())
self.assertFalse(check('target', 'cred', None, current_rule="a_rule"))
self.assertEqual(
[('target', 'cred', None)],
results,
)
class _BoolCheck(_checks.BaseCheck):
def __init__(self, result):
self.called = False
self.result = result
def __str__(self):
return str(self.result)
def __call__(self, target, creds, enforcer, current_rule=None):
self.called = True
return self.result
class AndCheckTestCase(test_base.BaseTestCase):
@ -371,29 +445,70 @@ class AndCheckTestCase(test_base.BaseTestCase):
self.assertEqual('(rule1 and rule2)', str(check))
def test_call_all_false(self):
rules = [mock.Mock(return_value=False), mock.Mock(return_value=False)]
rules = [
_BoolCheck(False),
_BoolCheck(False),
]
check = _checks.AndCheck(rules)
self.assertFalse(check('target', 'cred', None))
rules[0].assert_called_once_with('target', 'cred', None)
self.assertTrue(rules[0].called)
self.assertFalse(rules[1].called)
def test_call_first_true(self):
rules = [mock.Mock(return_value=True), mock.Mock(return_value=False)]
rules = [
_BoolCheck(True),
_BoolCheck(False),
]
check = _checks.AndCheck(rules)
self.assertFalse(check('target', 'cred', None))
rules[0].assert_called_once_with('target', 'cred', None)
rules[1].assert_called_once_with('target', 'cred', None)
self.assertTrue(rules[0].called)
self.assertTrue(rules[1].called)
def test_call_second_true(self):
rules = [mock.Mock(return_value=False), mock.Mock(return_value=True)]
rules = [
_BoolCheck(False),
_BoolCheck(True),
]
check = _checks.AndCheck(rules)
self.assertFalse(check('target', 'cred', None))
rules[0].assert_called_once_with('target', 'cred', None)
self.assertTrue(rules[0].called)
self.assertFalse(rules[1].called)
def test_rule_takes_current_rule(self):
results = []
class TestCheck(object):
def __call__(self, target, cred, enforcer, current_rule=None):
results.append((target, cred, enforcer, current_rule))
return False
check = _checks.AndCheck([TestCheck()])
self.assertFalse(check('target', 'cred', None, current_rule="a_rule"))
self.assertEqual(
[('target', 'cred', None, 'a_rule')],
results,
)
def test_rule_does_not_take_current_rule(self):
results = []
class TestCheck(object):
def __call__(self, target, cred, enforcer):
results.append((target, cred, enforcer))
return False
check = _checks.AndCheck([TestCheck()])
self.assertFalse(check('target', 'cred', None, current_rule="a_rule"))
self.assertEqual(
[('target', 'cred', None)],
results,
)
class OrCheckTestCase(test_base.BaseTestCase):
def test_init(self):
@ -420,25 +535,66 @@ class OrCheckTestCase(test_base.BaseTestCase):
self.assertEqual('(rule1 or rule2)', str(check))
def test_call_all_false(self):
rules = [mock.Mock(return_value=False), mock.Mock(return_value=False)]
rules = [
_BoolCheck(False),
_BoolCheck(False),
]
check = _checks.OrCheck(rules)
self.assertFalse(check('target', 'cred', None))
rules[0].assert_called_once_with('target', 'cred', None)
rules[1].assert_called_once_with('target', 'cred', None)
self.assertTrue(rules[0].called)
self.assertTrue(rules[1].called)
def test_call_first_true(self):
rules = [mock.Mock(return_value=True), mock.Mock(return_value=False)]
rules = [
_BoolCheck(True),
_BoolCheck(False),
]
check = _checks.OrCheck(rules)
self.assertTrue(check('target', 'cred', None))
rules[0].assert_called_once_with('target', 'cred', None)
self.assertTrue(rules[0].called)
self.assertFalse(rules[1].called)
def test_call_second_true(self):
rules = [mock.Mock(return_value=False), mock.Mock(return_value=True)]
rules = [
_BoolCheck(False),
_BoolCheck(True),
]
check = _checks.OrCheck(rules)
self.assertTrue(check('target', 'cred', None))
rules[0].assert_called_once_with('target', 'cred', None)
rules[1].assert_called_once_with('target', 'cred', None)
self.assertTrue(rules[0].called)
self.assertTrue(rules[1].called)
def test_rule_takes_current_rule(self):
results = []
class TestCheck(object):
def __call__(self, target, cred, enforcer, current_rule=None):
results.append((target, cred, enforcer, current_rule))
return False
check = _checks.OrCheck([TestCheck()])
self.assertFalse(check('target', 'cred', None, current_rule="a_rule"))
self.assertEqual(
[('target', 'cred', None, 'a_rule')],
results,
)
def test_rule_does_not_take_current_rule(self):
results = []
class TestCheck(object):
def __call__(self, target, cred, enforcer):
results.append((target, cred, enforcer))
return False
check = _checks.OrCheck([TestCheck()])
self.assertFalse(check('target', 'cred', None, current_rule="a_rule"))
self.assertEqual(
[('target', 'cred', None)],
results,
)