Merge branch 'develop' of git://github.com/kaidokert/pint
This commit is contained in:
commit
5218647d54
|
@ -147,5 +147,23 @@ To avoid the conversion of an argument or return value, use None
|
|||
>>> mypp3 = ureg.wraps((ureg.second, None), ureg.meter)(pendulum_period_error)
|
||||
|
||||
|
||||
Checking units
|
||||
==============
|
||||
|
||||
When you want pint quantities to be used as inputs to your functions, pint provides a wrapper to ensure units are of
|
||||
correct type - or more precisely, they match the expected dimensionality of the physical quantity.
|
||||
|
||||
Similar to wraps(), you can pass None to skip checking of some parameters, but the return parameter type is not checked.
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> mypp = ureg.check('[length]')(pendulum_period)
|
||||
|
||||
In the decorator format:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>>@ureg.check('[length]')
|
||||
... def pendulum_period(length):
|
||||
... return 2*math.pi*math.sqrt(length/G)
|
||||
|
||||
|
|
|
@ -75,6 +75,11 @@ try:
|
|||
except ImportError:
|
||||
from .nullhandler import NullHandler
|
||||
|
||||
try:
|
||||
from itertools import zip_longest
|
||||
except ImportError:
|
||||
from itertools import izip_longest as zip_longest
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
from numpy import ndarray
|
||||
|
|
|
@ -335,6 +335,45 @@ class TestRegistry(QuantityTestCase):
|
|||
h2 = ureg.wraps(('meter', 'cm'), [None, None])(hfunc)
|
||||
self.assertEqual(h2(3, 1), (3 * ureg.meter, 1 * ureg.cm))
|
||||
|
||||
def test_check(self):
|
||||
def func(x):
|
||||
return x
|
||||
|
||||
ureg = self.ureg
|
||||
|
||||
f0 = ureg.check('[length]')(func)
|
||||
self.assertRaises(AttributeError, f0, 3.)
|
||||
self.assertEqual(f0(3. * ureg.centimeter), 0.03 * ureg.meter)
|
||||
self.assertRaises(DimensionalityError, f0, 3. * ureg.kilogram)
|
||||
|
||||
f0b = ureg.check(ureg.meter)(func)
|
||||
self.assertRaises(AttributeError, f0b, 3.)
|
||||
self.assertEqual(f0b(3. * ureg.centimeter), 0.03 * ureg.meter)
|
||||
self.assertRaises(DimensionalityError, f0b, 3. * ureg.kilogram)
|
||||
|
||||
def gfunc(x, y):
|
||||
return x / y
|
||||
|
||||
g0 = ureg.check(None, None)(gfunc)
|
||||
self.assertEqual(g0(6, 2), 3)
|
||||
self.assertEqual(g0(6 * ureg.parsec, 2), 3 * ureg.parsec)
|
||||
|
||||
g1 = ureg.check('[speed]', '[time]')(gfunc)
|
||||
self.assertRaises(AttributeError, g1, 3.0, 1)
|
||||
self.assertRaises(DimensionalityError, g1, 1 * ureg.parsec, 1 * ureg.angstrom)
|
||||
self.assertRaises(TypeError, g1, 1 * ureg.km / ureg.hour, 1 * ureg.hour, 3.0)
|
||||
self.assertEqual(g1(3.6 * ureg.km / ureg.hour, 1 * ureg.second), 1 * ureg.meter / ureg.second ** 2)
|
||||
|
||||
g2 = ureg.check('[speed]')(gfunc)
|
||||
self.assertRaises(AttributeError, g2, 3.0, 1)
|
||||
self.assertRaises(DimensionalityError, g2, 2 * ureg.parsec)
|
||||
self.assertRaises(DimensionalityError, g2, 2 * ureg.parsec, 1.0)
|
||||
self.assertEqual(g2(2.0 * ureg.km / ureg.hour, 2), 1 * ureg.km / ureg.hour)
|
||||
|
||||
g3 = ureg.check('[speed]', '[time]', '[mass]')(gfunc)
|
||||
self.assertRaises(DimensionalityError, g3, 1 * ureg.parsec, 1 * ureg.angstrom)
|
||||
self.assertRaises(DimensionalityError, g3, 1 * ureg.parsec, 1 * ureg.angstrom, 1 * ureg.kilogram)
|
||||
|
||||
def test_to_ref_vs_to(self):
|
||||
self.ureg.autoconvert_offset_to_baseunit = True
|
||||
q = 8. * self.ureg.inch
|
||||
|
|
31
pint/unit.py
31
pint/unit.py
|
@ -29,7 +29,7 @@ from .util import (logger, pi_theorem, solve_dependencies, ParserHelper,
|
|||
string_preprocessor, find_connected_nodes,
|
||||
find_shortest_path, UnitsContainer, _is_dim,
|
||||
SharedRegistryObject, to_units_container)
|
||||
from .compat import tokenizer, string_types, NUMERIC_TYPES, long_type
|
||||
from .compat import tokenizer, string_types, NUMERIC_TYPES, long_type, zip_longest
|
||||
from .definitions import (Definition, UnitDefinition, PrefixDefinition,
|
||||
DimensionDefinition)
|
||||
from .converters import ScaleConverter
|
||||
|
@ -1145,6 +1145,35 @@ class UnitRegistry(object):
|
|||
return wrapper
|
||||
return decorator
|
||||
|
||||
def check(self, *args):
|
||||
"""Decorator to for quantity type checking for function inputs.
|
||||
|
||||
Use it to ensure that the decorated function input parameters match
|
||||
the expected type of pint quantity.
|
||||
|
||||
Use None to skip argument checking.
|
||||
|
||||
:param args: iterable of input units.
|
||||
:return: the wrapped function.
|
||||
:raises:
|
||||
:class:`TypeError` if the parameters don't match dimensions
|
||||
"""
|
||||
dimensions = [self.get_dimensionality(dim) for dim in args]
|
||||
|
||||
def decorator(func):
|
||||
assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr))
|
||||
updated = tuple(attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr))
|
||||
|
||||
@functools.wraps(func, assigned=assigned, updated=updated)
|
||||
def wrapper(*values, **kwargs):
|
||||
for dim, value in zip_longest(dimensions, values):
|
||||
if dim and value.dimensionality != dim:
|
||||
raise DimensionalityError(value, 'a quantity of',
|
||||
value.dimensionality, dim)
|
||||
return func(*values, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def build_unit_class(registry):
|
||||
|
||||
|
|
Loading…
Reference in New Issue