better order in Tensor class
@@ -184,7 +184,15 @@ class Tensor:
t.grad = g if t.grad is None else (t.grad + g)
del t0._ctx

# ***** non first class ops (hlops) *****
# ***** movement mlops *****

def reshape(self, shape, *args): return mlops.Reshape.apply(self, shape=argfix(shape, *args))
def expand(self, shape, *args): return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
def permute(self, order, *args): return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args): return mlops.Flip.apply(self, axis=argfix(axis, *args))
def slice(self, arg): return mlops.Slice.apply(self, arg=arg)
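
The movement mlops above all funnel their arguments through argfix, a tinygrad helper defined outside this diff that accepts a shape either as one tuple/list or as bare varargs. A minimal sketch of the idiom, not the library's exact definition:

# hypothetical sketch of argfix: normalize a tuple/list or bare varargs into a tuple
def argfix(*x): return tuple(x[0]) if isinstance(x[0], (tuple, list)) else tuple(x)

assert argfix((2, 3)) == argfix(2, 3) == (2, 3)
# so t.reshape((2, 3)) and t.reshape(2, 3) both reach Reshape with shape=(2, 3)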

# ***** movement hlops *****

# Tensors mostly follow the normal python indexing / slicing behavior for sequences
# - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
@@ -230,24 +238,9 @@ class Tensor:
slice_params[i][dim] = (k, min(self.shape[dim], k+self.shape[dim]//num))
return [self.slice(arg=p) for p in slice_params]

# TODO: what's the difference between dot and matmul?
def dot(self:Tensor, w:Tensor):
  # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
  bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
  cin, cout = w.shape[-2], w.shape[-1]
  out_shape_t = tuple(list(self.shape[0:-2])+[cout,-1])
  if len(self.shape) > 1:
    order = tuple(list(range(len(self.shape)-2))+[len(self.shape)-1, len(self.shape)-2])
  else:
    order, out_shape_t = (0,), (cout, )
  worder = tuple(list(range(len(w.shape)-2))+[len(w.shape)-1, len(w.shape)-2])

  # NOTE: with NHWC we can remove the transposes
  # bs x groups*cin x H x W
  cx = self.transpose(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
  # groups*cout x cin x H x W
  cw = w.transpose(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
  return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).transpose(order=order)
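
The NOTE in dot() can be checked outside tinygrad: an mxk @ kxn matmul is exactly a 1x1 convolution once the operands are laid out as (bs, cin, H, W). A numpy verification of the shape bookkeeping, with einsum standing in for conv2d since a 1x1 conv is just a per-pixel contraction over cin:

import numpy as np
m, k, n = 3, 4, 5
a, b = np.random.randn(m, k), np.random.randn(k, n)
x = a.T.reshape(1, k, m, 1)                # (bs=1, cin=k, H=m, W=1)
w = b.T.reshape(n, k, 1, 1)                # (cout=n, cin=k, 1, 1)
out = np.einsum('bchw,ocij->bohw', x, w)   # 1x1 conv: contract cin at each pixel
assert np.allclose(out.reshape(n, m).T, a @ b)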
def unsqueeze(self, dim):
  if dim < 0: dim = len(self.shape) + dim + 1
  return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])

# (padding_left, padding_right, padding_top, padding_bottom)
def pad2d(self, padding:Tuple[int, ...]): return self.slice(arg = [(0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1])])
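
pad2d gets padding for free from slice: a negative start index reads before the tensor, an end past the shape reads after it, and (assuming tinygrad's Slice zero-fills out-of-bounds reads, which this call relies on) the result is zero padding. np.pad expresses the same semantics:

import numpy as np
x = np.ones((1, 1, 2, 2))    # (bs, c, H, W)
pl, pr, pt, pb = 1, 1, 1, 1  # (padding_left, padding_right, padding_top, padding_bottom)
out = np.pad(x, ((0, 0), (0, 0), (pt, pb), (pl, pr)))
assert out.shape == (1, 1, 4, 4)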
@@ -255,6 +248,8 @@ class Tensor:
def transpose(self, order=(1,0)): return self.permute(order=order)
def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))

# ***** reduce ops *****

def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False):
  axis_ : List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
  axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
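
The two axis_ lines in _reduce normalize every accepted form (None, a single int, a tuple, negative indices) into a list of non-negative axes. Traced on a concrete input:

shape = (2, 3, 4)
axis = (-1, 0)
axis_ = list(range(len(shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
axis_ = [x if x >= 0 else x + len(shape) for x in axis_]
assert axis_ == [2, 0]   # -1 wrapped to 2; axis=None would give [0, 1, 2]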
@@ -284,11 +279,7 @@ class Tensor:
m, _, ss = self._softmax()
return m - ss.log()
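
log_softmax is m - ss.log(), where _softmax (its body is outside this hunk) evidently returns the max-subtracted input m along with the sum of its exponentials ss; subtracting the max first is the standard trick that keeps exp() from overflowing. A numpy check:

import numpy as np
x = np.array([1000.0, 1001.0, 1002.0])  # naive exp(x) would overflow to inf
m = x - x.max()
ss = np.exp(m).sum()
log_softmax = m - np.log(ss)
assert np.allclose(np.exp(log_softmax).sum(), 1.0)  # softmax still sums to 1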

def dropout(self, p=0.5) -> Tensor:
  if not Tensor.training:
    return self
  _mask : np.ndarray = np.asarray(Tensor._rng.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
  return self * Tensor(_mask, requires_grad=False, device=self.device) * (1/(1.0 - p))
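
This is inverted dropout: surviving activations are scaled by 1/(1-p) at training time so their expected value matches eval mode, where dropout is the identity. A quick numpy sanity check of that scaling:

import numpy as np
rng = np.random.default_rng(0)
p, x = 0.5, np.ones(100000)
mask = rng.binomial(1, 1.0 - p, size=x.shape)
y = x * mask * (1 / (1.0 - p))
assert abs(y.mean() - x.mean()) < 0.02   # E[y] == E[x] up to sampling noise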
# ***** processing ops *****

def _pool2d(self, ky, kx, sy, sx, dy=1, dx=1):
  if ky > sy or kx > sx or dy != 1 or dx != 1:
@@ -333,6 +324,32 @@ class Tensor:
ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1)).reshape(bs, cout, oy, ox)
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
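
The ret line above is the core of conv2d: broadcast-multiply patch views of the input against the weights, then sum over (cin, H, W). A numpy sketch of the same contraction with groups=1, where x is a hypothetical pre-built patch view (tinygrad builds it with movement ops):

import numpy as np
bs, cout, cin, oy, ox, H, W = 2, 8, 3, 5, 5, 3, 3
x = np.random.randn(bs, 1, oy, ox, cin, H, W)   # patches, broadcast over cout
w = np.random.randn(cout, cin, H, W)
ret = (x * w.reshape(1, cout, 1, 1, cin, H, W)).sum((-3, -2, -1))
assert np.allclose(ret, np.einsum('bqyxchw,ochw->boyx', x, w))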

def dot(self:Tensor, w:Tensor):
  # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
  bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
  cin, cout = w.shape[-2], w.shape[-1]
  out_shape_t = tuple(list(self.shape[0:-2])+[cout,-1])
  if len(self.shape) > 1:
    order = tuple(list(range(len(self.shape)-2))+[len(self.shape)-1, len(self.shape)-2])
  else:
    order, out_shape_t = (0,), (cout, )
  worder = tuple(list(range(len(w.shape)-2))+[len(w.shape)-1, len(w.shape)-2])

  # NOTE: with NHWC we can remove the transposes
  # bs x groups*cin x H x W
  cx = self.transpose(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
  # groups*cout x cin x H x W
  cw = w.transpose(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
  return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).transpose(order=order)

# ***** mlops (unary) *****

def contiguous(self): return mlops.Contiguous.apply(self)
def relu(self): return mlops.ReLU.apply(self)
def log(self): return mlops.Log.apply(self)
def exp(self): return mlops.Exp.apply(self)
def reciprocal(self): return mlops.Reciprocal.apply(self)

# ***** math functions (unary) *****

def __neg__(self): return 0.0-self
@@ -357,7 +374,7 @@ class Tensor:
def mish(self): return self * self.softplus().tanh()
def softplus(self, limit=20, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
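
mish and softplus compose entirely from the unary ops above (exp, log, plus tanh); note the limit parameter is accepted but unused in the definition shown. In plain numpy, with beta=1:

import numpy as np
x = np.linspace(-3.0, 3.0, 7)
softplus = np.log(1 + np.exp(x))   # (1/beta) * log(1 + exp(beta*x)), beta=1
mish = x * np.tanh(softplus)       # self * self.softplus().tanh()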

# ***** broadcasted binary ops *****
# ***** broadcasted binary mlops *****

@staticmethod
def broadcasted(fxn:Type[Function], tx:Union[Tensor, float], ty:Union[Tensor, float]):
@@ -367,14 +384,6 @@ class Tensor:
shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
return fxn.apply(x.expand(shape_ret), y.expand(shape_ret))
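
Before the lines above, broadcasted presumably left-pads the lower-rank operand's shape with 1s (that part is elided from the hunk); shape_ret then takes the elementwise max and both operands are expanded to it. Traced on a concrete pair:

x_shape, y_shape = (1, 4, 1), (3, 4, 5)
shape_ret = tuple(max(sx, sy) for sx, sy in zip(x_shape, y_shape))
assert shape_ret == (3, 4, 5)   # every size-1 axis stretches to match the other operand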

# ***** first class ops (mlops) *****

def contiguous(self): return mlops.Contiguous.apply(self)
def relu(self): return mlops.ReLU.apply(self)
def log(self): return mlops.Log.apply(self)
def exp(self): return mlops.Exp.apply(self)
def reciprocal(self): return mlops.Reciprocal.apply(self)

# NOTE: __pow__ and friends are broken in mypyc with the ** operator
def __add__(self, x): return Tensor.broadcasted(mlops.Add, self, x)
def __radd__(self, x): return Tensor.broadcasted(mlops.Add, x, self)
@@ -385,6 +394,8 @@ class Tensor:
def __pow__(self, x): return Tensor.broadcasted(mlops.Pow, self, x)
def __rpow__(self, x): return Tensor.broadcasted(mlops.Pow, x, self)

# ***** arithmetic hlops and wrappers *****

# non broadcasted ops
def __truediv__(self, x): return self * (x.reciprocal() if isinstance(x, Tensor) else (1/x))
def __rtruediv__(self, x): return self.reciprocal() * x
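
Division never needs its own mlop: a Tensor denominator goes through Reciprocal and a python scalar folds into a constant multiply, both riding on the identity a / b == a * (1 / b):

a, b = 6.0, 4.0
assert a * (1 / b) == a / b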
@@ -409,16 +420,6 @@ class Tensor:

# ***** functional nn ops *****

def reshape(self, shape, *args): return mlops.Reshape.apply(self, shape=argfix(shape, *args))
def expand(self, shape, *args): return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
def permute(self, order, *args): return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args): return mlops.Flip.apply(self, axis=argfix(axis, *args))
def slice(self, arg): return mlops.Slice.apply(self, arg=arg)

def unsqueeze(self, dim):
  if dim < 0: dim = len(self.shape) + dim + 1
  return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])

def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
  x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight) # type: ignore
  return x.add(bias) if bias is not None else x
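
linear dispatches on the weight's rank: a 1-D weight is an elementwise scale, anything 2-D or higher goes through dot. The shapes in numpy terms:

import numpy as np
x = np.random.randn(4, 8)
w1, w2 = np.random.randn(8), np.random.randn(8, 16)
assert (x * w1).shape == (4, 8)    # len(weight.shape) == 1: mul branch
assert (x @ w2).shape == (4, 16)   # otherwise: dot branch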
@@ -433,6 +434,11 @@ class Tensor:
x = (self - mean.reshape(shape=[1, -1, 1, 1])) * weight.reshape(shape=[1, -1, 1, 1])
return x.mul(invstd.reshape(shape=[1, -1, 1, 1])) + bias.reshape(shape=[1, -1, 1, 1])

def dropout(self, p=0.5) -> Tensor:
  if not Tensor.training: return self
  _mask : np.ndarray = np.asarray(Tensor._rng.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
  return self * Tensor(_mask, requires_grad=False, device=self.device) * (1/(1.0 - p))

# register functions to move between devices
for device in [device for device in Device._buffers.keys() if device[0] != "_"]:
  setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))
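
The closing loop stamps a lowercase convenience method onto Tensor for every registered backend via functools.partialmethod. Assuming "CPU" and "GPU" are keys in Device._buffers, the effect is:

t = Tensor([1.0, 2.0])
t_cpu = t.cpu()   # equivalent to t.to("CPU")
t_gpu = t.gpu()   # equivalent to t.to("GPU")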