pow cleanup part 3 (#4731)

Fast pow for int or (int+0.5) const exponent, and more comments.
This commit is contained in:
chenyu
2024-05-25 15:48:52 -04:00
committed by GitHub
parent de5c69c4c9
commit 8415b14978

View File

@@ -2414,24 +2414,23 @@ class Tensor:
# simple pow identities
if x < 0: return self.reciprocal().pow(-x)
if x == 0: return 1 + self * 0
if x in [3,2,1]: return functools.reduce(lambda acc,_: acc * self, range(int(x)-1), self)
if x == 0.5: return self.sqrt()
if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
# positive const ** self
if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
base, exponent = self._broadcasted(x, reverse=reverse)
# start with b ** e = exp(e * log(b))
ret = base.abs().log().mul(exponent).exp()
# correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
sign = (exponent * math.pi).cos()
negative_base = (base < 0).detach().where(1, 0)
# 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
correct_sign = sign * negative_base + (1 - negative_base)
# inject nan for negative base and non-integer exponent
to_nan = negative_base * (exponent != exponent.trunc()).detach()
# 0 -> 1; 1 -> nan
inject_nan = (-to_nan * 2 + 1).log().add(1)
ret = ret.mul(correct_sign).mul(inject_nan)
# fix 0 ** 0 = 1
return ((base == 0) * (exponent == 0)).detach().where(1, ret)
correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
# inject nan for negative base and non-integer exponent
inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
# apply correct_sign and inject_nan, and fix 0 ** 0 = 1
return ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
"""