diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index d6b0ce3bc8..f35cb8cc56 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -26,7 +26,7 @@ class ReLU(_UnaryOp):
 
 class Log(_UnaryOp):
   fop = UnaryOps.LOG
-  bop = BinaryOps.DIV
+  bop = BinaryOps.DIV  # TODO: flip order of DIV
 
 class Exp(_UnaryOp):
   def forward(ctx, input):
@@ -36,6 +36,8 @@ class Exp(_UnaryOp):
   bop = BinaryOps.MUL
 
 
+# TODO: add Neg?
+
 # ************* reduce ops *************
 
 class Sum(Function):
@@ -96,6 +98,8 @@ class Mul(Function):
     grad_y = ctx.binary_op(BinaryOps.MUL, x, grad_output) if ctx.needs_input_grad[1] else None
     return grad_x, grad_y
 
+# TODO: add Div?
+
 class Pow(Function):
   def forward(ctx, x, y):
     ret = ctx.binary_op(BinaryOps.POW, x, y)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index d9f0acb369..4829142246 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -41,6 +41,7 @@ class Tensor:
     return f"<Tensor {self.data!r} with grad {(self.grad.data if self.grad else None)!r}>"
 
   def realize(self):
+    # TODO: once lazy is upstreamed, we can remove this check
     if getattr(self.data, 'realize', None) is not None:
       self.data.realize()
 
@@ -65,6 +66,8 @@ class Tensor:
 
   # ***** creation helper functions *****
 
+  # TODO: remove use of numpy here
+
   @classmethod
   def zeros(cls, *shape, **kwargs):
     return cls(np.zeros(shape, dtype=np.float32), **kwargs)
@@ -155,7 +158,7 @@ class Tensor:
   def numpy(self):
     return np.array(self.cpu().data)
 
-  # ***** non first class ops *****
+  # ***** non first class ops (hlops) *****
 
   def __getitem__(self, val):
     arg = []
@@ -213,6 +216,7 @@ class Tensor:
     cw = w.transpose(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
     return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).transpose(order=order)
 
+  # TODO: what's the difference between dot and matmul?
   dot = matmul
 
   def transpose(self, order=(1,0)):
@@ -243,43 +247,6 @@ class Tensor:
     out = self.sum(axis=axis, keepdim=keepdim)
     return out * (prod(out.shape)/prod(self.shape))
 
-  def sqrt(self):
-    return self.pow(0.5)
-
-  def div(self, y):
-    return self * (y ** -1.0)
-  __truediv__ = div
-
-  def sigmoid(self):
-    #e = self.exp(); return e.div(1 + e)
-    return (1.0 + (0.0-self).exp()) ** -1.0
-
-  def elu(self, alpha=1.0):
-    return self.relu() - (-alpha*(self.exp() - 1)).relu()
-
-  def swish(self):
-    return self * self.sigmoid()
-
-  def relu6(self):
-    return self.relu() - (self-6).relu()
-
-  def clip(self, min, max):
-    return ((self-min).relu()+min) - (self-max).relu()
-
-  def hardswish(self):
-    return self * (self+3).relu6() * (1/6)
-
-  def tanh(self):
-    return 2.0 * ((2.0 * self).sigmoid()) - 1.0
-
-  def gelu(x):
-    # https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py
-    #import torch; return Tensor(torch.nn.functional.gelu(torch.tensor(x.data)).numpy())
-    return 0.5 * x * (1 + (x * 0.7978845608 * (1 + 0.044715 * x * x)).tanh())
-
-  def leakyrelu(self, neg_slope=0.01):
-    return self.relu() - (-neg_slope*self).relu()
-
   def _softmax(self):
     m = self - self.max(axis=len(self.shape)-1, keepdim=True)
     e = m.exp()
@@ -298,21 +265,7 @@ class Tensor:
     _mask = np.asarray(np.random.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
     return self * Tensor(_mask, requires_grad=False, device=self.device) * (1/(1.0 - p))
 
-  def softplus(self, limit=20, beta=1):
-    # safe softplus - 1/beta*log(1 + exp(beta*x)) (PyTorch)
-    eb = (self*beta).exp()
-    ret = (1 + eb).log()
-    return (1/beta)*ret
-
-  def mish(self):
-    return self * (self.softplus().tanh()) # x*tanh(softplus(x))
-
-  def abs(self):
-    return self.relu() + (-1.0*self).relu()
-
-  def sign(self):
-    return self / (self.abs() + 1e-10)
-
+  # TODO: support arbitrary strides
   def _pool2d(self, py, px):
     xup = self[:, :, :self.shape[2]-self.shape[2]%py, :self.shape[3]-self.shape[3]%px] if (self.shape[2]%py != 0) or (self.shape[3]%px != 0) else self
     return xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//py, py, xup.shape[3]//px, px))
@@ -327,6 +280,27 @@ class Tensor:
     ret = self._conv2d(weight, **kwargs)
     return ret if bias is None else ret.add(bias.reshape(shape=[1, -1, 1, 1]))
 
+  # ***** math functions (unary) *****
+
+  def sqrt(self): return self.pow(0.5)
+  def clip(self, min, max): return ((self-min).relu()+min) - (self-max).relu()
+  def abs(self): return self.relu() + (0.0-self).relu()
+  def sign(self): return self / (self.abs() + 1e-10)
+
+  # ***** activation functions (unary) *****
+
+  # TODO: make "-self" work
+  def sigmoid(self): return (1.0 + (0.0-self).exp()) ** -1.0
+  def elu(self, alpha=1.0): return self.relu() - (-alpha*(self.exp() - 1)).relu()
+  def swish(self): return self * self.sigmoid()
+  def relu6(self): return self.relu() - (self-6).relu()
+  def hardswish(self): return self * (self+3).relu6() * (1/6)
+  def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
+  def gelu(x): return 0.5 * x * (1 + (x * 0.7978845608 * (1 + 0.044715 * x * x)).tanh())
+  def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu()
+  def mish(self): return self * self.softplus().tanh()
+  def softplus(self, limit=20, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
+
   # ***** broadcasted binary ops *****
 
   @staticmethod
@@ -350,9 +324,14 @@ class Tensor:
 
   def mul(self, x): return Tensor.broadcasted(Tensor._mul, self, x)
   def pow(self, x): return Tensor.broadcasted(Tensor._pow, self, x)
 
+  # TODO: should be broadcasted binary op
+  def div(self, y):
+    return self * (y ** -1.0)
+  __truediv__ = div
+
   # ***** functional nn ops *****
 
-  # TODO: fix the kwargs problem
+  # TODO: fix the kwargs problem, then remove these
   def reshape(self, shape): return self._reshape(shape=shape)
   def expand(self, shape): return self._expand(shape=shape)
@@ -390,8 +369,7 @@ class Function(Ops):
     ctx = cls(x[0].device, *x)
     ret = Tensor(ctx.forward(*[t.data for t in x], **kwargs),
                  device=ctx.device, requires_grad=ctx.requires_grad)
-    if ctx.requires_grad:
-      ret._ctx = ctx  # used by autograd engine
+    if ctx.requires_grad: ret._ctx = ctx  # used by autograd engine
     return ret
 
 # register functions to move between devices
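
Reviewer note: a quick sanity-check sketch of the relocated hlops, not part of the patch. It assumes a tinygrad checkout with this diff applied; the numpy reference values are mine. The point of the move is that every hlop is composed purely from the first-class ops (relu, exp, log, pow, add, sub, mul), so backends only ever implement those primitives.

    import numpy as np
    from tinygrad.tensor import Tensor

    data = np.array([-3.0, -0.5, 0.0, 0.5, 3.0], dtype=np.float32)
    x = Tensor(data)

    # abs is relu(x) + relu(-x)
    assert np.allclose(x.abs().numpy(), np.abs(data))

    # clip(x, lo, hi) is (relu(x - lo) + lo) - relu(x - hi)
    assert np.allclose(x.clip(-1.0, 2.0).numpy(), np.clip(data, -1.0, 2.0))

    # tanh is derived from sigmoid: 2*sigmoid(2x) - 1
    assert np.allclose(x.tanh().numpy(), np.tanh(data), atol=1e-6)

    # div is still mul by pow(-1); per the TODO above it should
    # eventually become a broadcasted binary op like add/sub/mul/pow
    assert np.allclose((x / 2.0).numpy(), data / 2.0)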