Merge branch 'develop' of git://github.com/kaidokert/pint

This commit is contained in:
Hernan Grecco 2015-07-21 15:42:47 -03:00
commit 5218647d54
4 changed files with 92 additions and 1 deletions

View File

@ -147,5 +147,23 @@ To avoid the conversion of an argument or return value, use None
>>> mypp3 = ureg.wraps((ureg.second, None), ureg.meter)(pendulum_period_error)
Checking units
==============
When you want pint quantities to be used as inputs to your functions, pint provides a wrapper to ensure units are of
correct type - or more precisely, they match the expected dimensionality of the physical quantity.
Similar to wraps(), you can pass None to skip checking of some parameters, but the return parameter type is not checked.
.. doctest::
>>> mypp = ureg.check('[length]')(pendulum_period)
In the decorator format:
.. doctest::
>>>@ureg.check('[length]')
... def pendulum_period(length):
... return 2*math.pi*math.sqrt(length/G)

View File

@ -75,6 +75,11 @@ try:
except ImportError:
from .nullhandler import NullHandler
try:
from itertools import zip_longest
except ImportError:
from itertools import izip_longest as zip_longest
try:
import numpy as np
from numpy import ndarray

View File

@ -335,6 +335,45 @@ class TestRegistry(QuantityTestCase):
h2 = ureg.wraps(('meter', 'cm'), [None, None])(hfunc)
self.assertEqual(h2(3, 1), (3 * ureg.meter, 1 * ureg.cm))
def test_check(self):
def func(x):
return x
ureg = self.ureg
f0 = ureg.check('[length]')(func)
self.assertRaises(AttributeError, f0, 3.)
self.assertEqual(f0(3. * ureg.centimeter), 0.03 * ureg.meter)
self.assertRaises(DimensionalityError, f0, 3. * ureg.kilogram)
f0b = ureg.check(ureg.meter)(func)
self.assertRaises(AttributeError, f0b, 3.)
self.assertEqual(f0b(3. * ureg.centimeter), 0.03 * ureg.meter)
self.assertRaises(DimensionalityError, f0b, 3. * ureg.kilogram)
def gfunc(x, y):
return x / y
g0 = ureg.check(None, None)(gfunc)
self.assertEqual(g0(6, 2), 3)
self.assertEqual(g0(6 * ureg.parsec, 2), 3 * ureg.parsec)
g1 = ureg.check('[speed]', '[time]')(gfunc)
self.assertRaises(AttributeError, g1, 3.0, 1)
self.assertRaises(DimensionalityError, g1, 1 * ureg.parsec, 1 * ureg.angstrom)
self.assertRaises(TypeError, g1, 1 * ureg.km / ureg.hour, 1 * ureg.hour, 3.0)
self.assertEqual(g1(3.6 * ureg.km / ureg.hour, 1 * ureg.second), 1 * ureg.meter / ureg.second ** 2)
g2 = ureg.check('[speed]')(gfunc)
self.assertRaises(AttributeError, g2, 3.0, 1)
self.assertRaises(DimensionalityError, g2, 2 * ureg.parsec)
self.assertRaises(DimensionalityError, g2, 2 * ureg.parsec, 1.0)
self.assertEqual(g2(2.0 * ureg.km / ureg.hour, 2), 1 * ureg.km / ureg.hour)
g3 = ureg.check('[speed]', '[time]', '[mass]')(gfunc)
self.assertRaises(DimensionalityError, g3, 1 * ureg.parsec, 1 * ureg.angstrom)
self.assertRaises(DimensionalityError, g3, 1 * ureg.parsec, 1 * ureg.angstrom, 1 * ureg.kilogram)
def test_to_ref_vs_to(self):
self.ureg.autoconvert_offset_to_baseunit = True
q = 8. * self.ureg.inch

View File

@ -29,7 +29,7 @@ from .util import (logger, pi_theorem, solve_dependencies, ParserHelper,
string_preprocessor, find_connected_nodes,
find_shortest_path, UnitsContainer, _is_dim,
SharedRegistryObject, to_units_container)
from .compat import tokenizer, string_types, NUMERIC_TYPES, long_type
from .compat import tokenizer, string_types, NUMERIC_TYPES, long_type, zip_longest
from .definitions import (Definition, UnitDefinition, PrefixDefinition,
DimensionDefinition)
from .converters import ScaleConverter
@ -1145,6 +1145,35 @@ class UnitRegistry(object):
return wrapper
return decorator
def check(self, *args):
"""Decorator to for quantity type checking for function inputs.
Use it to ensure that the decorated function input parameters match
the expected type of pint quantity.
Use None to skip argument checking.
:param args: iterable of input units.
:return: the wrapped function.
:raises:
:class:`TypeError` if the parameters don't match dimensions
"""
dimensions = [self.get_dimensionality(dim) for dim in args]
def decorator(func):
assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr))
updated = tuple(attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr))
@functools.wraps(func, assigned=assigned, updated=updated)
def wrapper(*values, **kwargs):
for dim, value in zip_longest(dimensions, values):
if dim and value.dimensionality != dim:
raise DimensionalityError(value, 'a quantity of',
value.dimensionality, dim)
return func(*values, **kwargs)
return wrapper
return decorator
def build_unit_class(registry):