pow cleanup part 2 (#4727)

more cleanups and fix 0 ** 0
This commit is contained in:
chenyu
2024-05-25 07:17:40 -04:00
committed by GitHub
parent 85e57223bd
commit 7e90026eb0
2 changed files with 12 additions and 10 deletions

View File

@@ -419,8 +419,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: 2.0**x)
helper_test_op([()], lambda x: x**2.0)
helper_test_op([()], lambda x: 2.0**x)
# TODO: fix 0**x and 0**0 == 1
# helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]])
# TODO: fix backward
helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)
# TODO: fix backward, should be nan
helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)

View File

@@ -2420,16 +2420,18 @@ class Tensor:
base, exponent = self._broadcasted(x, reverse=reverse)
ret = base.abs().log().mul(exponent).exp()
# correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the exponent)
# correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
sign = (exponent * math.pi).cos()
# we only need to correct the sign if the base is negative
base_sign = ((base.sign()) - 1) / -2
# we need 0 to be positive so we need to correct base_sign when the base is 0
base_sign = base_sign - (1.5 * (1 - (base.sign().abs())))
# inject nan if the base is negative and the exponent is not an integer
to_nan = (exponent != exponent.trunc()).detach() * base_sign
negative_base = (base < 0).detach().where(1, 0)
# 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
correct_sign = sign * negative_base + (1 - negative_base)
# inject nan for negative base is negative and non-integer exponent
to_nan = negative_base * (exponent != exponent.trunc()).detach()
# 0 -> 1; 1 -> nan
inject_nan = (-to_nan * 2 + 1).log().add(1)
return ret.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
ret = ret.mul(correct_sign).mul(inject_nan)
# fix 0 ** 0 = 1
return ((base == 0) * (exponent == 0)).detach().where(1, ret)
def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
"""