use *0+1 for Tensor.pow base case, remove function.Zero (#4023)

one less mlops!
This commit is contained in:
chenyu
2024-03-31 19:20:44 -04:00
committed by GitHub
parent 276ef8eb87
commit 23c912e338
2 changed files with 1 additions and 5 deletions

View File

@@ -24,10 +24,6 @@ class Cast(Function):
# ************* unary ops *************
class Zero(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.const(0)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)
class Neg(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(UnaryOps.NEG)

View File

@@ -902,7 +902,7 @@ class Tensor:
if not isinstance(x, Tensor) and not reverse:
# simple pow identities
if x < 0: return self.reciprocal().pow(-x)
if x in [3,2,1,0]: return functools.reduce(lambda acc,_: acc * self, range(int(x)), F.Zero.apply(self)+1)
if x in [3,2,1,0]: return functools.reduce(lambda acc,_: acc * self, range(int(x)), self*0+1)
if x == 0.5: return self.sqrt()
if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()