mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
pow cleanup part 3 (#4731)
fast pow for int or (int+0.5) const exponent. and more comments
This commit is contained in:
@@ -2414,24 +2414,23 @@ class Tensor:
|
||||
# simple pow identities
|
||||
if x < 0: return self.reciprocal().pow(-x)
|
||||
if x == 0: return 1 + self * 0
|
||||
if x in [3,2,1]: return functools.reduce(lambda acc,_: acc * self, range(int(x)-1), self)
|
||||
if x == 0.5: return self.sqrt()
|
||||
if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
|
||||
if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
|
||||
|
||||
# positive const ** self
|
||||
if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
|
||||
|
||||
base, exponent = self._broadcasted(x, reverse=reverse)
|
||||
# start with b ** e = exp(e * log(b))
|
||||
ret = base.abs().log().mul(exponent).exp()
|
||||
# correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
|
||||
sign = (exponent * math.pi).cos()
|
||||
negative_base = (base < 0).detach().where(1, 0)
|
||||
# 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
|
||||
correct_sign = sign * negative_base + (1 - negative_base)
|
||||
# inject nan for negative base is negative and non-integer exponent
|
||||
to_nan = negative_base * (exponent != exponent.trunc()).detach()
|
||||
# 0 -> 1; 1 -> nan
|
||||
inject_nan = (-to_nan * 2 + 1).log().add(1)
|
||||
ret = ret.mul(correct_sign).mul(inject_nan)
|
||||
# fix 0 ** 0 = 1
|
||||
return ((base == 0) * (exponent == 0)).detach().where(1, ret)
|
||||
correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
|
||||
# inject nan for negative base and non-integer exponent
|
||||
inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
|
||||
# apply correct_sign inject_nan, and fix 0 ** 0 = 1
|
||||
return ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
|
||||
|
||||
def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user