From 4f4ecbec976ffbcec30e35db681aa65962167825 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Tue, 6 Sep 2022 17:39:26 -0700
Subject: [PATCH] add div to operators

---
 tinygrad/tensor.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 7c968f6b05..bed9c65d76 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -11,6 +11,7 @@ from tinygrad.ops import LazyBuffer
 # **** start with two base classes, Tensor and Function ****
 
 class Tensor:
+  # TODO: remove no_init when uniform is late bind
   training, no_grad, no_init = False, False, False
 
   def __init__(self, data, device=Device.DEFAULT, requires_grad=True):
@@ -255,7 +256,6 @@ class Tensor:
   # ***** activation functions (unary) *****
 
   def sigmoid(self): return (1.0 + (-self).exp()).reciprocal()
-  # TODO: implement generic constant folding
   def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
   def swish(self): return self * self.sigmoid()
   silu = swish # The SiLU function is also known as the swish function.
@@ -284,7 +284,6 @@ class Tensor:
   def mul(self, x): return Tensor.broadcasted(Tensor._mul, self, x)
   def pow(self, x): return Tensor.broadcasted(Tensor._pow, self, x)
   def div(self, y): return self * (y.reciprocal() if isinstance(y, Tensor) else (1/y))
-  __truediv__ = div
 
   # ***** functional nn ops *****
 
@@ -339,10 +338,9 @@ for name, cls in inspect.getmembers(importlib.import_module('tinygrad.mlops'), i
   register(name.lower(), cls)
 
 # register the operators
-# TODO: add div
 def register_op(name, fxn):
   setattr(Tensor, f"__{name}__", fxn)
   setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(fxn(self,x)))
   setattr(Tensor, f"__r{name}__", lambda self,x: fxn(x,self))
-for name in ['add', 'sub', 'mul', 'pow', 'matmul']:
-  register_op(name, getattr(Tensor, name))
\ No newline at end of file
+for name in ['add', 'sub', 'mul', 'pow', 'matmul', 'truediv']:
+  register_op(name, getattr(Tensor, name if name != 'truediv' else 'div'))
\ No newline at end of file
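
Usage note (not part of the patch): the registration loop above wires up __truediv__, __itruediv__, and __rtruediv__ on Tensor in one pass, reusing Tensor.div, which rewrites a / b as a * b.reciprocal() rather than introducing a new backward op. The snippet below is a minimal, self-contained sketch of that pattern under stated assumptions: Box and its methods (assign, reciprocal, mul, div) are toy stand-ins invented for illustration, not tinygrad's real Tensor; only register_op and the name-remapping loop mirror what the patch does.

# Toy sketch of the operator-registration pattern this patch extends with 'truediv'.
# Box is a stand-in for tinygrad's Tensor (assumption for illustration only).
class Box:
  def __init__(self, val): self.val = val
  def assign(self, x):
    # in-place forms (a /= b) route through assign, as Tensor.assign does
    self.val = x.val
    return self
  def reciprocal(self): return Box(1.0 / self.val)
  def mul(self, x):
    # wrap plain numbers so the reflected case (4.0 / b) also lands here
    lhs = self if isinstance(self, Box) else Box(self)
    rhs = x if isinstance(x, Box) else Box(x)
    return Box(lhs.val * rhs.val)
  # same shape as the patch's Tensor.div: a / b == a * (1 / b)
  def div(self, y): return Box.mul(self, y.reciprocal() if isinstance(y, Box) else 1 / y)

# mirrors register_op from the patch, applied to the toy class
def register_op(name, fxn):
  setattr(Box, f"__{name}__", fxn)                                         # a / b
  setattr(Box, f"__i{name}__", lambda self, x: self.assign(fxn(self, x)))  # a /= b
  setattr(Box, f"__r{name}__", lambda self, x: fxn(x, self))               # 4.0 / b

for name in ['mul', 'truediv']:
  register_op(name, getattr(Box, name if name != 'truediv' else 'div'))

a, b = Box(6.0), Box(3.0)
print((a / b).val)    # 2.0      via __truediv__
print((a / 2.0).val)  # 3.0      scalar divisor handled inside div()
print((4.0 / b).val)  # ~1.3333  via __rtruediv__
a /= b                # __itruediv__ -> assign
print(a.val)          # 2.0

Running the sketch exercises the plain, scalar, reflected, and in-place division paths; the design point it illustrates is that defining div in terms of reciprocal and mul lets one setattr loop expose all the / dunders without adding a new mlops Function.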