Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-07 03:00:26 -04:00.
use *0+1 for Tensor.pow base case, remove function.Zero (#4023)
one less mlops!
This commit is contained in:
@@ -24,10 +24,6 @@ class Cast(Function):
|
||||
|
||||
# ************* unary ops *************
|
||||
|
||||
class Zero(Function):
  """Constant-zero op: the output is a zero-filled buffer shaped like the input.

  Both directions produce constants, so no data from the input (or the
  incoming gradient) ever flows through.
  """
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    # result is identically zero regardless of x's contents
    return x.const(0)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    # derivative of a constant is zero
    return grad_output.const(0)
|
||||
|
||||
class Neg(Function):
  """Elementwise negation (UnaryOps.NEG).

  Negation is linear, so the backward pass is simply the negated incoming
  gradient: d(-x)/dx = -1.
  """
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    return x.e(UnaryOps.NEG)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.e(UnaryOps.NEG)
|
||||
|
||||
@@ -902,7 +902,7 @@ class Tensor:
|
||||
if not isinstance(x, Tensor) and not reverse:
|
||||
# simple pow identities
|
||||
if x < 0: return self.reciprocal().pow(-x)
|
||||
if x in [3,2,1,0]: return functools.reduce(lambda acc,_: acc * self, range(int(x)), F.Zero.apply(self)+1)
|
||||
if x in [3,2,1,0]: return functools.reduce(lambda acc,_: acc * self, range(int(x)), self*0+1)
|
||||
if x == 0.5: return self.sqrt()
|
||||
if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
|
||||
ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
|
||||
|
||||
Reference in New Issue
Block a user