mirror of https://github.com/tinygrad/tinygrad.git
add div to operators
@@ -11,6 +11,7 @@ from tinygrad.ops import LazyBuffer

# **** start with two base classes, Tensor and Function ****

class Tensor:
  # TODO: remove no_init when uniform is late bind
  training, no_grad, no_init = False, False, False

  def __init__(self, data, device=Device.DEFAULT, requires_grad=True):
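For context (not part of the diff): training, no_grad, and no_init are class attributes, so toggling them on the class affects every Tensor at once. A minimal standalone sketch of that pattern, using a toy class rather than tinygrad's Tensor:

# Standalone sketch (toy class, not tinygrad) of the class-level flag pattern above:
# the flags live on the class, so flipping them is a global switch seen by every instance.
class MiniTensor:
  training, no_grad, no_init = False, False, False
  def __init__(self, data): self.data = data

MiniTensor.training = True          # global switch, e.g. while running a training step
t = MiniTensor([1.0, 2.0])
print(t.training, t.no_grad)        # True False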
@@ -255,7 +256,6 @@ class Tensor:
  # ***** activation functions (unary) *****

  def sigmoid(self): return (1.0 + (-self).exp()).reciprocal()
  # TODO: implement generic constant folding
  def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
  def swish(self): return self * self.sigmoid()
  silu = swish  # The SiLU function is also known as the swish function.
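As a quick sanity check (standalone NumPy, not tinygrad), the closed forms above match the textbook definitions of sigmoid, ELU, and swish/SiLU:

# Standalone NumPy check (not tinygrad) that the formulas above match the usual definitions.
import numpy as np

x = np.linspace(-4, 4, 9)
relu = lambda v: np.maximum(v, 0)

sigmoid = 1.0 / (1.0 + np.exp(-x))              # reciprocal of (1 + exp(-x))
elu     = relu(x) - 1.0 * relu(1 - np.exp(x))   # alpha = 1.0
swish   = x * sigmoid                           # a.k.a. SiLU

assert np.allclose(elu, np.where(x > 0, x, np.exp(x) - 1))   # ELU: x if x>0 else exp(x)-1
assert np.isclose(sigmoid[4], 0.5)                            # sigmoid(0) == 0.5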
@@ -284,7 +284,6 @@ class Tensor:
  def mul(self, x): return Tensor.broadcasted(Tensor._mul, self, x)
  def pow(self, x): return Tensor.broadcasted(Tensor._pow, self, x)
  def div(self, y): return self * (y.reciprocal() if isinstance(y, Tensor) else (1/y))
  __truediv__ = div

  # ***** functional nn ops *****
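The isinstance branch in div above means dividing by a Tensor multiplies by its reciprocal, while dividing by a plain Python number multiplies by 1/y. A standalone toy sketch of that dispatch (not tinygrad):

# Standalone sketch (toy class, not tinygrad) of the dispatch in div() above.
class FakeTensor:
  def __init__(self, v): self.v = v
  def reciprocal(self): return FakeTensor(1.0 / self.v)
  def __mul__(self, other):
    o = other.v if isinstance(other, FakeTensor) else other
    return FakeTensor(self.v * o)
  def div(self, y):
    return self * (y.reciprocal() if isinstance(y, FakeTensor) else (1 / y))

print(FakeTensor(6.0).div(FakeTensor(3.0)).v)  # 2.0  (Tensor path: multiply by reciprocal)
print(FakeTensor(6.0).div(4).v)                # 1.5  (scalar path: multiply by 1/4)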
@@ -339,10 +338,9 @@ for name, cls in inspect.getmembers(importlib.import_module('tinygrad.mlops'), i
  register(name.lower(), cls)

# register the operators
-# TODO: add div
def register_op(name, fxn):
  setattr(Tensor, f"__{name}__", fxn)
  setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(fxn(self,x)))
  setattr(Tensor, f"__r{name}__", lambda self,x: fxn(x,self))
-for name in ['add', 'sub', 'mul', 'pow', 'matmul']:
-  register_op(name, getattr(Tensor, name))
+for name in ['add', 'sub', 'mul', 'pow', 'matmul', 'truediv']:
+  register_op(name, getattr(Tensor, name if name != 'truediv' else 'div'))
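With 'truediv' in the registration loop, __truediv__, __itruediv__, and __rtruediv__ are all generated from Tensor.div, which makes a hand-written __truediv__ = div alias redundant. A standalone sketch of the same setattr pattern on a toy class (not tinygrad):

# Standalone sketch (toy class, not tinygrad) of the register_op pattern above:
# one setattr each for __name__, __i{name}__ (in-place via assign), and __r{name}__ (reflected).
class Num:
  def __init__(self, v): self.v = v
  def assign(self, other): self.v = other.v; return self
  def add(self, x): return Num(self.v + x.v)
  def div(self, y): return Num(self.v / (y.v if isinstance(y, Num) else y))

def register_op(name, fxn):
  setattr(Num, f"__{name}__", fxn)
  setattr(Num, f"__i{name}__", lambda self, x: self.assign(fxn(self, x)))
  setattr(Num, f"__r{name}__", lambda self, x: fxn(x, self))

for name in ['add', 'truediv']:
  register_op(name, getattr(Num, name if name != 'truediv' else 'div'))

a = Num(6.0)
print((a / Num(3.0)).v)   # 2.0, via the generated __truediv__
a /= 2                    # __itruediv__ assigns the result back in place
print(a.v)                # 3.0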