mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -419,8 +419,8 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65)], lambda x: 2.0**x)
|
||||
helper_test_op([()], lambda x: x**2.0)
|
||||
helper_test_op([()], lambda x: 2.0**x)
|
||||
# TODO: fix 0**x and 0**0 == 1
|
||||
# helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]])
|
||||
# TODO: fix backward
|
||||
helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)
|
||||
# TODO: fix backward, should be nan
|
||||
helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)
|
||||
|
||||
|
||||
@@ -2420,16 +2420,18 @@ class Tensor:
|
||||
|
||||
base, exponent = self._broadcasted(x, reverse=reverse)
|
||||
ret = base.abs().log().mul(exponent).exp()
|
||||
# correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the exponent)
|
||||
# correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
|
||||
sign = (exponent * math.pi).cos()
|
||||
# we only need to correct the sign if the base is negative
|
||||
base_sign = ((base.sign()) - 1) / -2
|
||||
# we need 0 to be positive so we need to correct base_sign when the base is 0
|
||||
base_sign = base_sign - (1.5 * (1 - (base.sign().abs())))
|
||||
# inject nan if the base is negative and the exponent is not an integer
|
||||
to_nan = (exponent != exponent.trunc()).detach() * base_sign
|
||||
negative_base = (base < 0).detach().where(1, 0)
|
||||
# 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
|
||||
correct_sign = sign * negative_base + (1 - negative_base)
|
||||
# inject nan for negative base is negative and non-integer exponent
|
||||
to_nan = negative_base * (exponent != exponent.trunc()).detach()
|
||||
# 0 -> 1; 1 -> nan
|
||||
inject_nan = (-to_nan * 2 + 1).log().add(1)
|
||||
return ret.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
|
||||
ret = ret.mul(correct_sign).mul(inject_nan)
|
||||
# fix 0 ** 0 = 1
|
||||
return ((base == 0) * (exponent == 0)).detach().where(1, ret)
|
||||
|
||||
def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user