trove/trove/tests/util/check.py

205 lines
7.0 KiB
Python

# Copyright (c) 2012 OpenStack
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Like asserts, but does not raise an exception until the end of a block."""
import traceback
from proboscis.asserts import assert_equal
from proboscis.asserts import assert_false
from proboscis.asserts import assert_not_equal
from proboscis.asserts import assert_true
from proboscis.asserts import ASSERTION_ERROR
from proboscis.asserts import Check
def get_stack_trace_of_caller(level_up):
    """Return the current call stack with the innermost entries dropped.

    The stack is captured from inside this function. Then the last
    ``level_up + 1`` entries are removed: this function's own entry plus
    ``level_up`` further callers. If that is more entries than exist, an
    empty list is returned.

    :param level_up: number of caller frames to drop beyond this function.
    :returns: list of ``traceback.FrameSummary`` entries, outermost first.
    """
    frames = traceback.extract_stack()
    keep = len(frames) - (level_up + 1)
    # Clamp so an over-large level_up yields an empty slice, not a
    # negative (from-the-end) index.
    return frames[:max(keep, 0)]
def raise_blame_caller(level_up, ex):
    """Raises an exception, changing the stack trace to point to the caller.

    A traceback is rebuilt from live frames and attached to ``ex``. Its
    innermost entry is the frame ``level_up + 2`` levels above this
    function, so ``level_up=0`` blames the function that called this
    function's caller. This is the same frame that
    ``get_stack_trace_of_caller(level_up + 2)`` would end at. If no such
    frame exists, ``ex`` is raised without a prior traceback.

    As the exception propagates, the interpreter still records this
    function's frame (and every frame it unwinds through) ahead of the
    rebuilt chain.

    Requires Python 3.7+ (``types.TracebackType`` became constructible).

    :param level_up: extra frames to skip beyond the caller's caller.
    :param ex: exception instance to raise.
    :raises: ``ex``, always.
    """
    import sys
    import types

    # with_traceback() only accepts a real traceback object (or None).
    # A list of FrameSummary entries, as produced by
    # traceback.extract_stack(), makes it raise TypeError instead of ``ex``.
    # So build a genuine traceback chain from the frames themselves.
    try:
        frame = sys._getframe(level_up + 2)
    except ValueError:
        # The stack is shallower than requested: nothing to blame.
        frame = None
    tb = None
    # Walk outward from the blamed frame. Each new entry links to the
    # previous one via tb_next, so the finished chain runs outermost ->
    # blamed frame, like a normal traceback.
    while frame is not None:
        tb = types.TracebackType(tb, frame, frame.f_lasti, frame.f_lineno)
        frame = frame.f_back
    raise ex.with_traceback(tb)
class Checker(object):
    """Runs assertions, deferring failures to the end of a ``with`` block.

    Outside a ``with`` block each check raises immediately. Inside one,
    assertion failures (``ASSERTION_ERROR``) are formatted and collected in
    ``self.messages``. When the block exits without another exception, all
    collected failures are raised together as one ``ASSERTION_ERROR``.
    """

    def __init__(self):
        # Formatted failure reports collected while protected.
        self.messages = []
        # Flips after every recorded failure so successive reports use
        # alternating "* " / "- " markers.
        self.odd = True
        # True only between __enter__ and __exit__.
        self.protected = False

    def _add_exception(self, _type, value, tb):
        """Format an exception and the stack where it occurred; record it.

        :param _type: exception class.
        :param value: exception instance.
        :param tb: list of stack entries (e.g. from
                   ``traceback.extract_stack``) to include in the report.
        """
        if self.odd:
            prefix = "* "
        else:
            prefix = "- "
        start = "Check failure! Traceback:"
        # NOTE(review): join() puts the marker *between* formatted frames,
        # so the first frame carries no marker -- confirm this is intended.
        middle = prefix.join(traceback.format_list(tb))
        end = '\n'.join(traceback.format_exception_only(_type, value))
        msg = '\n'.join([start, middle, end])
        self.messages.append(msg)
        self.odd = not self.odd

    def equal(self, *args, **kwargs):
        """Check ``assert_equal(*args, **kwargs)``."""
        self._run_assertion(assert_equal, *args, **kwargs)

    def false(self, *args, **kwargs):
        """Check ``assert_false(*args, **kwargs)``."""
        self._run_assertion(assert_false, *args, **kwargs)

    def not_equal(self, *args, **kwargs):
        """Check ``assert_not_equal(*args, **kwargs)``."""
        # Must go through self: there is no module-level _run_assertion.
        self._run_assertion(assert_not_equal, *args, **kwargs)

    def _run_assertion(self, assert_func, *args, **kwargs):
        """Run an assertion, recording rather than raising when protected.

        When protected, an ``ASSERTION_ERROR`` from ``assert_func`` is
        recorded. The recorded stack is trimmed to end at the code that
        called the public check method (``equal``, ``true``, ...). Other
        exceptions propagate. When not protected, ``assert_func`` runs
        directly and any failure raises immediately.
        """
        if self.protected:
            try:
                assert_func(*args, **kwargs)
            except ASSERTION_ERROR as ae:
                # Skip get_stack_trace_of_caller, this method and the
                # public check method that called it.
                st = get_stack_trace_of_caller(2)
                self._add_exception(ASSERTION_ERROR, ae, st)
        else:
            assert_func(*args, **kwargs)

    def __enter__(self):
        """Enter protected mode; returns this checker."""
        self.protected = True
        return self

    def __exit__(self, _type, value, tb):
        """Leave protected mode and report any recorded failures.

        If the block raised, returns False so that exception propagates
        (recorded failures are not reported in that case). Otherwise, if
        any failures were recorded, raises a single ``ASSERTION_ERROR``
        whose message joins them all. Recorded messages are not cleared.
        """
        self.protected = False
        if _type is not None:
            # An error occurred other than an assertion failure.
            # Return False to allow the Exception to be raised
            return False
        if self.messages:
            final_message = '\n'.join(self.messages)
            raise ASSERTION_ERROR(final_message)

    def true(self, *args, **kwargs):
        """Check ``assert_true(*args, **kwargs)``."""
        self._run_assertion(assert_true, *args, **kwargs)
class AttrCheck(Check):
    """Checks for attributes, links and other common response items."""

    def __init__(self):
        super(AttrCheck, self).__init__()

    def fail(self, msg):
        """Record a failure with ``msg`` by running ``self.true(False, msg)``."""
        self.true(False, msg)

    def contains_allowed_attrs(self, list, allowed_attrs, msg=None):
        """Fail once for every item of ``list`` not found in ``allowed_attrs``.

        :param list: iterable of attribute names to inspect. The name shadows
            the builtin but is kept for keyword-argument compatibility.
        :param allowed_attrs: container of permitted names.
        :param msg: label used in the failure message. When omitted, the
            message begins with the literal text ``None``.
        """
        for attr in list:
            if attr in allowed_attrs:
                continue
            self.fail("%s should not contain '%s'" % (msg, attr))

    def links(self, links):
        """Check that every link exposes only ``href`` and ``rel``.

        Each link is iterated and passed to ``contains_allowed_attrs``.
        Presumably links are dicts, so their keys are what gets checked;
        confirm against callers.
        """
        permitted = ['href', 'rel']
        for link in links:
            self.contains_allowed_attrs(link, permitted, msg="Links")
class CollectionCheck(Check):
    """Checks for elements in a dictionary.

    :param name: label used in failure messages, shown as ``name.key``.
    :param collection: the mapping whose entries are checked.
    """

    def __init__(self, name, collection):
        self.name = name
        self.collection = collection
        super(CollectionCheck, self).__init__()

    def element_equals(self, key, expected_value):
        """Check that ``collection[key]`` equals ``expected_value``.

        Fails via ``self.fail`` when ``key`` is absent. Otherwise compares
        using ``self.equal(actual, expected_value)``.
        """
        if key in self.collection:
            self.equal(self.collection[key], expected_value)
        else:
            self.fail('Element "%s.%s" does not exist.' % (self.name, key))

    def has_element(self, key, element_type):
        """Check that ``collection[key]`` exists and has an expected type.

        :param key: key to look up in the collection.
        :param element_type: a type, or a tuple of types. A ``None`` entry
            matches a value that is itself ``None``.
        """
        if key not in self.collection:
            self.fail('Element "%s.%s" does not exist.' % (self.name, key))
            return
        value = self.collection[key]
        type_list = (element_type if isinstance(element_type, tuple)
                     else [element_type])
        # Build a full list (not a generator) so every candidate is
        # evaluated. An invalid type anywhere in the tuple still raises
        # TypeError from isinstance.
        results = [(value is None) if (candidate is None)
                   else isinstance(value, candidate)
                   for candidate in type_list]
        if not any(results):
            self.fail('Element "%s.%s" does not match any of these '
                      'expected types: %s' % (self.name, key, type_list))
class TypeCheck(Check):
    """Checks for attributes in an object.

    :param name: label used in failure messages.
    :param instance: the object whose attributes are checked.
    """

    def __init__(self, name, instance):
        self.name = name
        self.instance = instance
        super(TypeCheck, self).__init__()

    def _check_type(self, value, attribute_type, attribute_name=None):
        """Fail unless ``value`` is an instance of ``attribute_type``.

        ``self`` and ``attribute_name`` are parameters so the method is
        callable: the failure message reads both ``self.name`` and the
        attribute's name.

        :param value: the value to check.
        :param attribute_type: a type or tuple of types, as accepted by
            ``isinstance``.
        :param attribute_name: attribute name shown in the failure message.
        """
        if not isinstance(value, attribute_type):
            self.fail("%s attribute %s is of type %s (expected %s)."
                      % (self.name, attribute_name, type(value),
                         attribute_type))

    def has_field(self, attribute_name, attribute_type,
                  additional_checks=None):
        """Check that the instance has an attribute of an expected type.

        :param attribute_name: name of the attribute to look up.
        :param attribute_type: a type, or a tuple of types. A ``None`` entry
            matches a value that is itself ``None``.
        :param additional_checks: optional callable, invoked with the value
            only when the attribute exists and its type matched.
        """
        if not hasattr(self.instance, attribute_name):
            self.fail("%s missing attribute %s." % (self.name, attribute_name))
        else:
            value = getattr(self.instance, attribute_name)
            match = False
            if isinstance(attribute_type, tuple):
                type_list = attribute_type
            else:
                type_list = [attribute_type]
            # Evaluate every candidate (no early exit), so an invalid type
            # anywhere in the tuple raises TypeError from isinstance.
            for possible_type in type_list:
                if possible_type is None:
                    if value is None:
                        match = True
                else:
                    if isinstance(value, possible_type):
                        match = True
            if not match:
                self.fail("%s attribute %s is of type %s (expected one of "
                          "the following: %s)." % (self.name, attribute_name,
                                                  type(value),
                                                  attribute_type))
            if match and additional_checks:
                additional_checks(value)