Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-22 21:38:10 -05:00
cleanups to Tensor class
@@ -26,7 +26,7 @@ class ReLU(_UnaryOp):

 class Log(_UnaryOp):
   fop = UnaryOps.LOG
-  bop = BinaryOps.DIV
+  bop = BinaryOps.DIV # TODO: flip order of DIV

 class Exp(_UnaryOp):
   def forward(ctx, input):
@@ -36,6 +36,8 @@ class Exp(_UnaryOp):

   bop = BinaryOps.MUL

+# TODO: add Neg?
+
 # ************* reduce ops *************

 class Sum(Function):
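Note: the bop choices encode the derivatives: d/dx log(x) = 1/x, so Log's backward divides by the input (the TODO is about the argument order of DIV), and d/dx exp(x) = exp(x), so Exp's backward multiplies by the saved output. A quick finite-difference check in plain NumPy, outside tinygrad's op API:

import numpy as np

# finite-difference check of the derivatives behind Log's DIV and Exp's MUL backward ops
x = np.array([0.5, 1.0, 2.0, 4.0])
eps = 1e-6
d_log = (np.log(x + eps) - np.log(x - eps)) / (2 * eps)
d_exp = (np.exp(x + eps) - np.exp(x - eps)) / (2 * eps)
assert np.allclose(d_log, 1.0 / x, atol=1e-5)    # grad flows through a divide by the input
assert np.allclose(d_exp, np.exp(x), atol=1e-4)  # grad flows through a multiply by the output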
@@ -96,6 +98,8 @@ class Mul(Function):
     grad_y = ctx.binary_op(BinaryOps.MUL, x, grad_output) if ctx.needs_input_grad[1] else None
     return grad_x, grad_y

+# TODO: add Div?
+
 class Pow(Function):
   def forward(ctx, x, y):
     ret = ctx.binary_op(BinaryOps.POW, x, y)
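Note: Mul.backward is the product rule, grad_x = y * grad_output and grad_y = x * grad_output, each gated on needs_input_grad. A small NumPy finite-difference check of that rule (not using tinygrad):

import numpy as np

# product rule check: for loss = sum(x * y), dloss/dx_i = y_i
x, y = np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0, 6.0])
grad_output = np.ones_like(x)          # upstream gradient of sum()
grad_x, grad_y = y * grad_output, x * grad_output

eps = 1e-6
num_grad_x = np.array([(np.sum((x + eps*e) * y) - np.sum((x - eps*e) * y)) / (2*eps)
                       for e in np.eye(len(x))])
assert np.allclose(num_grad_x, grad_x, atol=1e-4)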
@@ -41,6 +41,7 @@ class Tensor:
     return f"<Tensor {self.data!r} with grad {(self.grad.data if self.grad else None)!r}>"

   def realize(self):
+    # TODO: once lazy is upstreamed, we can remove this check
     if getattr(self.data, 'realize', None) is not None:
       self.data.realize()

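Note: the new comment documents why the guard exists: realize() should be a no-op for eager buffers (plain numpy arrays) and only forward to buffers that actually expose a realize method. A standalone sketch of that duck-typing pattern; FakeLazyBuffer is an illustrative stand-in, not tinygrad's class:

import numpy as np

class FakeLazyBuffer:
  # illustrative stand-in for a lazy backend buffer
  def __init__(self): self.realized = False
  def realize(self): self.realized = True

def realize(data):
  # same guard as in Tensor.realize: only call realize() if the buffer has one
  if getattr(data, 'realize', None) is not None:
    data.realize()

realize(np.zeros(4))      # numpy array has no realize(); nothing happens
buf = FakeLazyBuffer()
realize(buf)
assert buf.realized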
@@ -65,6 +66,8 @@ class Tensor:

   # ***** creation helper functions *****

+  # TODO: remove use of numpy here
+
   @classmethod
   def zeros(cls, *shape, **kwargs):
     return cls(np.zeros(shape, dtype=np.float32), **kwargs)
@@ -155,7 +158,7 @@ class Tensor:
   def numpy(self):
     return np.array(self.cpu().data)

-  # ***** non first class ops *****
+  # ***** non first class ops (hlops) *****

   def __getitem__(self, val):
     arg = []
@@ -213,6 +216,7 @@ class Tensor:
     cw = w.transpose(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
     return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).transpose(order=order)

+  # TODO: what's the difference between dot and matmul?
   dot = matmul

   def transpose(self, order=(1,0)):
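Note: matmul is routed through conv2d here by reshaping the weight to (groups*cout, cin, 1, 1), i.e. a grouped 1x1 convolution. The identity behind it: a 1x1 conv over an (N, cin, H, W) tensor is a matrix multiply over the channel axis. A NumPy check for the groups=1 case (the grouped case just splits the channels the same way):

import numpy as np

N, cin, cout, H, W = 2, 3, 5, 4, 4
x = np.random.randn(N, cin, H, W)
w = np.random.randn(cout, cin, 1, 1)

# 1x1 convolution, written as an einsum over the channel axis
conv = np.einsum('nchw,oc->nohw', x, w[:, :, 0, 0])

# the same thing as a plain matmul on flattened spatial positions
mat = (w[:, :, 0, 0] @ x.reshape(N, cin, H * W)).reshape(N, cout, H, W)

assert np.allclose(conv, mat)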
@@ -243,43 +247,6 @@ class Tensor:
     out = self.sum(axis=axis, keepdim=keepdim)
     return out * (prod(out.shape)/prod(self.shape))

-  def sqrt(self):
-    return self.pow(0.5)
-
-  def div(self, y):
-    return self * (y ** -1.0)
-  __truediv__ = div
-
-  def sigmoid(self):
-    #e = self.exp(); return e.div(1 + e)
-    return (1.0 + (0.0-self).exp()) ** -1.0
-
-  def elu(self, alpha=1.0):
-    return self.relu() - (-alpha*(self.exp() - 1)).relu()
-
-  def swish(self):
-    return self * self.sigmoid()
-
-  def relu6(self):
-    return self.relu() - (self-6).relu()
-
-  def clip(self, min, max):
-    return ((self-min).relu()+min) - (self-max).relu()
-
-  def hardswish(self):
-    return self * (self+3).relu6() * (1/6)
-
-  def tanh(self):
-    return 2.0 * ((2.0 * self).sigmoid()) - 1.0
-
-  def gelu(x):
-    # https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py
-    #import torch; return Tensor(torch.nn.functional.gelu(torch.tensor(x.data)).numpy())
-    return 0.5 * x * (1 + (x * 0.7978845608 * (1 + 0.044715 * x * x)).tanh())
-
-  def leakyrelu(self, neg_slope=0.01):
-    return self.relu() - (-neg_slope*self).relu()
-
   def _softmax(self):
     m = self - self.max(axis=len(self.shape)-1, keepdim=True)
     e = m.exp()
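Note: nothing is lost in this removal; the same methods reappear below as one-liners in the new "math functions" / "activation functions" sections. The composite definitions rest on standard identities, e.g. tanh(x) = 2*sigmoid(2*x) - 1 and the tanh approximation of GELU. A NumPy check of both (the erf form is the exact GELU it approximates):

import numpy as np
from math import erf

x = np.linspace(-3, 3, 13)

# tanh via sigmoid, as used in Tensor.tanh
sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
assert np.allclose(np.tanh(x), 2.0 * sigmoid(2.0 * x) - 1.0)

# tanh approximation of GELU vs the exact erf form; they agree to within ~1e-3
gelu_tanh = 0.5 * x * (1 + np.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))
gelu_erf = np.array([0.5 * v * (1 + erf(v / np.sqrt(2))) for v in x])
assert np.allclose(gelu_tanh, gelu_erf, atol=1e-2)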
@@ -298,21 +265,7 @@ class Tensor:
     _mask = np.asarray(np.random.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
     return self * Tensor(_mask, requires_grad=False, device=self.device) * (1/(1.0 - p))

-  def softplus(self, limit=20, beta=1):
-    # safe softplus - 1/beta*log(1 + exp(beta*x)) (PyTorch)
-    eb = (self*beta).exp()
-    ret = (1 + eb).log()
-    return (1/beta)*ret
-
-  def mish(self):
-    return self * (self.softplus().tanh()) # x*tanh(softplus(x))
-
-  def abs(self):
-    return self.relu() + (-1.0*self).relu()
-
-  def sign(self):
-    return self / (self.abs() + 1e-10)
-
+  # TODO: support arbitrary strides
   def _pool2d(self, py, px):
     xup = self[:, :, :self.shape[2]-self.shape[2]%py, :self.shape[3]-self.shape[3]%px] if (self.shape[2]%py != 0) or (self.shape[3]%px != 0) else self
     return xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//py, py, xup.shape[3]//px, px))
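Note: _pool2d first crops H and W down to multiples of the pool size, then reshapes so each pooling window becomes its own pair of axes; max/avg pooling in the callers then reduce over those axes. The same reshape trick in NumPy, for 2x2 max pooling:

import numpy as np

x = np.arange(2 * 3 * 4 * 6, dtype=np.float32).reshape(2, 3, 4, 6)
py, px = 2, 2

# crop so H and W are multiples of the pool size (here they already are)
x = x[:, :, :x.shape[2] - x.shape[2] % py, :x.shape[3] - x.shape[3] % px]

# expose the pooling windows as axes 3 and 5, then reduce over them
xup = x.reshape(x.shape[0], x.shape[1], x.shape[2] // py, py, x.shape[3] // px, px)
pooled = xup.max(axis=(3, 5))
assert pooled.shape == (2, 3, 2, 3)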
@@ -327,6 +280,27 @@ class Tensor:
     ret = self._conv2d(weight, **kwargs)
     return ret if bias is None else ret.add(bias.reshape(shape=[1, -1, 1, 1]))

+  # ***** math functions (unary) *****
+
+  def sqrt(self): return self.pow(0.5)
+  def clip(self, min, max): return ((self-min).relu()+min) - (self-max).relu()
+  def abs(self): return self.relu() + (0.0-self).relu()
+  def sign(self): return self / (self.abs() + 1e-10)
+
+  # ***** activation functions (unary) *****
+
+  # TODO: make "-self" work
+  def sigmoid(self): return (1.0 + (0.0-self).exp()) ** -1.0
+  def elu(self, alpha=1.0): return self.relu() - (-alpha*(self.exp() - 1)).relu()
+  def swish(self): return self * self.sigmoid()
+  def relu6(self): return self.relu() - (self-6).relu()
+  def hardswish(self): return self * (self+3).relu6() * (1/6)
+  def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
+  def gelu(x): return 0.5 * x * (1 + (x * 0.7978845608 * (1 + 0.044715 * x * x)).tanh())
+  def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu()
+  def mish(self): return self * self.softplus().tanh()
+  def softplus(self, limit=20, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
+
   # ***** broadcasted binary ops *****

   @staticmethod
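Note: the one-liners build everything from primitives Tensor already has (relu, exp, log, pow and broadcasted arithmetic). The relu-based rewrites are exact: abs(x) = relu(x) + relu(-x), clip(x, lo, hi) = (relu(x - lo) + lo) - relu(x - hi), relu6(x) = relu(x) - relu(x - 6), leakyrelu(x) = relu(x) - relu(-slope*x). A NumPy check, with np.maximum(., 0) standing in for relu:

import numpy as np

relu = lambda z: np.maximum(z, 0)
x = np.linspace(-10, 10, 101)

assert np.allclose(np.abs(x), relu(x) + relu(0.0 - x))
assert np.allclose(np.clip(x, -2, 3), (relu(x - (-2)) + (-2)) - relu(x - 3))
assert np.allclose(np.clip(x, 0, 6), relu(x) - relu(x - 6))                  # relu6
assert np.allclose(np.where(x > 0, x, 0.01 * x), relu(x) - relu(-0.01 * x))  # leakyrelu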
@@ -350,9 +324,14 @@ class Tensor:
   def mul(self, x): return Tensor.broadcasted(Tensor._mul, self, x)
   def pow(self, x): return Tensor.broadcasted(Tensor._pow, self, x)

+  # TODO: should be broadcasted binary op
+  def div(self, y):
+    return self * (y ** -1.0)
+  __truediv__ = div
+
   # ***** functional nn ops *****

-  # TODO: fix the kwargs problem
+  # TODO: fix the kwargs problem, then remove these
   def reshape(self, shape): return self._reshape(shape=shape)
   def expand(self, shape): return self._expand(shape=shape)

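Note: div is defined as multiplication by y ** -1.0 instead of a dedicated broadcasted binary op (hence the TODO). Because mul and pow are already broadcasted, division still broadcasts for free. A NumPy check of the identity and the broadcast shape:

import numpy as np

x = np.random.randn(3, 4).astype(np.float32)
y = np.abs(np.random.randn(4).astype(np.float32)) + 0.5  # keep divisors away from zero

assert np.allclose(x / y, x * (y ** -1.0), rtol=1e-5)  # div as multiply-by-reciprocal
assert (x * (y ** -1.0)).shape == (3, 4)               # (3,4) / (4,) broadcasts as expected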
@@ -390,8 +369,7 @@ class Function(Ops):
     ctx = cls(x[0].device, *x)
     ret = Tensor(ctx.forward(*[t.data for t in x], **kwargs),
                  device=ctx.device, requires_grad=ctx.requires_grad)
-    if ctx.requires_grad:
-      ret._ctx = ctx # used by autograd engine
+    if ctx.requires_grad: ret._ctx = ctx # used by autograd engine
     return ret

 # register functions to move between devices
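Note: folding the requires_grad check onto one line is behaviour-preserving: apply() still attaches the context only when a gradient is needed, and ret._ctx is the hook the autograd engine walks during backward. A heavily simplified sketch of that pattern (a toy, not tinygrad's actual engine; MiniTensor and MulFn are illustrative names):

class MiniTensor:
  def __init__(self, data, requires_grad=True):
    self.data, self.grad = data, 0.0
    self.requires_grad, self._ctx = requires_grad, None

class MulFn:
  @classmethod
  def apply(cls, x, y):
    ctx = cls()
    ctx.parents, ctx.saved = (x, y), (x.data, y.data)
    ret = MiniTensor(x.data * y.data, requires_grad=x.requires_grad or y.requires_grad)
    if ret.requires_grad: ret._ctx = ctx  # same hook as Function.apply above
    return ret
  def backward(self, grad_output):
    x, y = self.saved
    return y * grad_output, x * grad_output

a, b = MiniTensor(3.0), MiniTensor(4.0)
out = MulFn.apply(a, b)
# one step of backprop: the engine walks _ctx from the output to its parents
for parent, g in zip(out._ctx.parents, out._ctx.backward(1.0)):
  parent.grad += g
assert (a.grad, b.grad) == (4.0, 3.0)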