Allow array exponents (with len > 1) when base is dimensionless

Close #483
This commit is contained in:
Hernan Grecco 2017-04-15 00:20:29 -03:00
parent e64ebc2e08
commit a82a6331af
3 changed files with 37 additions and 8 deletions

View File

@ -893,9 +893,21 @@ class _Quantity(SharedRegistryObject):
if isinstance(getattr(other, '_magnitude', other), ndarray):
# arrays are refused as exponent, because they would create
# len(array) quanitites of len(set(array)) different units
if np.size(other) > 1:
raise DimensionalityError(self._units, 'dimensionless')
# len(array) quantities of len(set(array)) different units
# unless the base is dimensionless.
if self.dimensionless:
if getattr(other, 'dimensionless', False):
self._magnitude **= other.m_as('')
return self
elif not getattr(other, 'dimensionless', True):
raise DimensionalityError(other._units, 'dimensionless')
else:
self._magnitude **= other
return self
elif np.size(other) > 1:
raise DimensionalityError(self._units, 'dimensionless',
extra_msg='Quantity array exponents are only allowed '
'if the base is dimensionless')
if other == 1:
return self
@ -930,9 +942,19 @@ class _Quantity(SharedRegistryObject):
if isinstance(getattr(other, '_magnitude', other), ndarray):
# arrays are refused as exponent, because they would create
# len(array) quantities of len(set(array)) different units
if np.size(other) > 1:
raise DimensionalityError(self._units, 'dimensionless')
# len(array) quantities of len(set(array)) different units
# unless the base is dimensionless.
if self.dimensionless:
if getattr(other, 'dimensionless', False):
return self.__class__(self.m ** other.m_as(''))
elif not getattr(other, 'dimensionless', True):
raise DimensionalityError(other._units, 'dimensionless')
else:
return self.__class__(self.m ** other)
elif np.size(other) > 1:
raise DimensionalityError(self._units, 'dimensionless',
extra_msg='Quantity array exponents are only allowed '
'if the base is dimensionless')
new_self = self
if other == 1:

View File

@ -546,4 +546,11 @@ class TestIssuesNP(QuantityTestCase):
x = ureg.Quantity(1., 'meter')
y = f(x)
z = x * y
self.assertEquals(z, ureg.Quantity(1., 'meter * kilogram'))
self.assertEquals(z, ureg.Quantity(1., 'meter * kilogram'))
def test_issue483(self):
ureg = self.ureg
a = np.asarray([1, 2, 3])
q = [1, 2, 3] * ureg.dimensionless
p = (q ** q).m
np.testing.assert_array_equal(p, a ** a)

View File

@ -427,7 +427,7 @@ class TestNDArrayQunatityMath(QuantityTestCase):
@helpers.requires_numpy()
def test_exponentiation_array_exp(self):
arr = np.array(range(3), dtype=np.float)
q = self.Q_(arr, None)
q = self.Q_(arr, 'meter')
for op_ in [op.pow, op.ipow]:
q_cp = copy.copy(q)