Better ordering of ops in the Tensor class

This commit is contained in:
George Hotz
2023-02-23 19:33:37 -08:00
parent f2ca81c66d
commit 10c6ccf7e0

View File

@@ -184,7 +184,15 @@ class Tensor:
t.grad = g if t.grad is None else (t.grad + g)
del t0._ctx
# ***** non first class ops (hlops) *****
# ***** movement mlops *****
# Thin wrappers over the first-class movement mlops.
# argfix lets callers pass either a tuple/list or splatted varargs, e.g. x.reshape((2,3)) or x.reshape(2,3).
def reshape(self, shape, *args): return mlops.Reshape.apply(self, shape=argfix(shape, *args))
# A -1 in the requested shape keeps that axis at its current size.
def expand(self, shape, *args): return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
def permute(self, order, *args): return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args): return mlops.Flip.apply(self, axis=argfix(axis, *args))
# arg is a per-axis (start, end) list; pad2d below passes negative/overhanging bounds,
# so mlops.Slice presumably pads the overhang — confirm against the mlop.
def slice(self, arg): return mlops.Slice.apply(self, arg=arg)
# ***** movement hlops *****
# Tensors mostly follow the normal python indexing / slicing behavior for sequences
# - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
@@ -230,24 +238,9 @@ class Tensor:
slice_params[i][dim] = (k, min(self.shape[dim], k+self.shape[dim]//num))
return [self.slice(arg=p) for p in slice_params]
# TODO: what's the difference between dot and matmul?
def dot(self:Tensor, w:Tensor):
    """Generalized matrix multiply, lowered onto a 1x1 conv2d.

    mxk @ kxn is computed as (1,k,m,1).conv2d(n,k,1,1); the leading dims of
    self become the conv batch and the leading dims of w become conv groups.
    """
    cin, cout = w.shape[-2], w.shape[-1]
    bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
    if len(self.shape) > 1:
        # permutation that swaps the last two axes of self, leading axes untouched
        order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
        out_shape_t = tuple(self.shape[0:-2]) + (cout, -1)
    else:
        # 1-D input: nothing to transpose, the result is a vector of size cout
        order, out_shape_t = (0,), (cout,)
    # permutation that swaps the last two axes of w
    worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
    # NOTE: with NHWC we could remove the transposes
    cx = self.transpose(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))  # bs x groups*cin x H x W
    cw = w.transpose(order=worder).reshape(shape=(groups*cout, cin, 1, 1))           # groups*cout x cin x 1 x 1
    return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).transpose(order=order)
def unsqueeze(self, dim):
    """Insert a new size-1 axis at position dim.

    Negative dims count from the end, so dim=-1 appends a trailing axis.
    """
    if dim < 0:
        dim += len(self.shape) + 1
    new_shape = self.shape[:dim] + (1,) + self.shape[dim:]
    return self.reshape(new_shape)
# padding is (padding_left, padding_right, padding_top, padding_bottom)
def pad2d(self, padding:Tuple[int, ...]):
    """Pad the last two axes of an NCHW tensor by slicing past its bounds.

    The overhanging region is filled by mlops.Slice (presumably with zeros —
    confirm against the mlop's implementation).
    """
    left, right, top, bottom = padding[0], padding[1], padding[2], padding[3]
    return self.slice(arg = [
        (0, self.shape[0]),                      # batch: untouched
        (0, self.shape[1]),                      # channels: untouched
        (-top, self.shape[2] + bottom),          # height
        (-left, self.shape[3] + right),          # width
    ])
@@ -255,6 +248,8 @@ class Tensor:
def transpose(self, order=(1,0)):
    """Reorder axes; the default swaps the first two (a 2-D matrix transpose)."""
    return self.permute(order=order)
def flatten(self, start_dim=0):
    """Collapse every axis from start_dim onward into one trailing axis (-1 lets reshape infer its size)."""
    kept = self.shape[0:start_dim]
    return self.reshape(shape=tuple(kept) + (-1,))
# ***** reduce ops *****
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False):
axis_ : List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
@@ -284,11 +279,7 @@ class Tensor:
m, _, ss = self._softmax()
return m - ss.log()
def dropout(self, p=0.5) -> Tensor:
    """Inverted dropout: zero each element with probability p, rescale survivors by 1/(1-p).

    Identity when Tensor.training is False. The keep-mask is drawn on the
    host RNG (Tensor._rng) as a numpy array and wrapped in a non-grad Tensor.
    """
    if not Tensor.training:
        return self
    keep = Tensor._rng.binomial(1, 1.0-p, size=self.shape)
    _mask : np.ndarray = np.asarray(keep, dtype=self.dtype)
    return self * Tensor(_mask, requires_grad=False, device=self.device) * (1/(1.0 - p))
# ***** processing ops *****
def _pool2d(self, ky, kx, sy, sx, dy=1, dx=1):
if ky > sy or kx > sx or dy != 1 or dx != 1:
@@ -333,6 +324,32 @@ class Tensor:
ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1)).reshape(bs, cout, oy, ox)
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
def dot(self:Tensor, w:Tensor):
    """Matrix multiply built on a 1x1 conv2d.

    mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1); extra leading dims of self map to
    the conv batch dimension and extra leading dims of w map to conv groups.
    """
    bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
    cin, cout = w.shape[-2], w.shape[-1]
    out_shape_t = tuple(list(self.shape[0:-2])+[cout,-1])
    if len(self.shape) == 1:
        # vector input: no axes to swap, output collapses to (cout,)
        order, out_shape_t = (0,), (cout,)
    else:
        # permutation swapping the last two axes of self
        order = tuple(list(range(len(self.shape)-2))+[len(self.shape)-1, len(self.shape)-2])
    # permutation swapping the last two axes of w
    worder = tuple(list(range(len(w.shape)-2))+[len(w.shape)-1, len(w.shape)-2])
    # NOTE: with NHWC we could remove the transposes
    # lay self out as bs x groups*cin x H x W, and w as groups*cout x cin x 1 x 1
    cx = self.transpose(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
    cw = w.transpose(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
    return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).transpose(order=order)
# ***** mlops (unary) *****
# One-line dispatchers: each forwards straight to the corresponding mlop's apply().
def contiguous(self): return mlops.Contiguous.apply(self)
def relu(self): return mlops.ReLU.apply(self)
def log(self): return mlops.Log.apply(self)
def exp(self): return mlops.Exp.apply(self)
def reciprocal(self): return mlops.Reciprocal.apply(self)
# ***** math functions (unary) *****
# 0.0 - self dispatches through Python's reflected subtraction on the Tensor (its __rsub__).
def __neg__(self): return 0.0-self
@@ -357,7 +374,7 @@ class Tensor:
def mish(self):
    """mish activation: x * tanh(softplus(x))."""
    gate = self.softplus().tanh()
    return self * gate
def softplus(self, limit=20, beta=1):
    """softplus: (1/beta) * log(1 + exp(beta*x)).

    NOTE(review): the limit parameter is accepted but unused in this body.
    """
    scaled = (self*beta).exp()
    return (1/beta) * (1 + scaled).log()
# ***** broadcasted binary ops *****
# ***** broadcasted binary mlops *****
@staticmethod
def broadcasted(fxn:Type[Function], tx:Union[Tensor, float], ty:Union[Tensor, float]):
@@ -367,14 +384,6 @@ class Tensor:
shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
return fxn.apply(x.expand(shape_ret), y.expand(shape_ret))
# ***** first class ops (mlops) *****
# Thin wrappers: each unary op delegates directly to its autograd-aware mlop.
def contiguous(self): return mlops.Contiguous.apply(self)
def relu(self): return mlops.ReLU.apply(self)
def log(self): return mlops.Log.apply(self)
def exp(self): return mlops.Exp.apply(self)
def reciprocal(self): return mlops.Reciprocal.apply(self)
# NOTE: __pow__ and friends are broken in mypyc with the ** operator
# broadcasted() expands both operands to a common shape before applying the mlop.
def __add__(self, x): return Tensor.broadcasted(mlops.Add, self, x)
def __radd__(self, x): return Tensor.broadcasted(mlops.Add, x, self)
@@ -385,6 +394,8 @@ class Tensor:
# Pow uses the same broadcasting path as the other binary mlops; the reflected
# form just swaps the operand order before broadcasting.
def __pow__(self, x): return Tensor.broadcasted(mlops.Pow, self, x)
def __rpow__(self, x): return Tensor.broadcasted(mlops.Pow, x, self)
# ***** arithmetic hlops and wrappers *****
# non broadcasted ops
# Division is composed from mul and reciprocal instead of a dedicated mlop;
# a plain-number divisor is inverted in Python, a Tensor divisor via .reciprocal().
def __truediv__(self, x): return self * (x.reciprocal() if isinstance(x, Tensor) else (1/x))
def __rtruediv__(self, x): return self.reciprocal() * x
@@ -409,16 +420,6 @@ class Tensor:
# ***** functional nn ops *****
# Movement op wrappers; argfix accepts either a single tuple/list or splatted varargs.
def reshape(self, shape, *args): return mlops.Reshape.apply(self, shape=argfix(shape, *args))
# A -1 axis in the target shape keeps that axis at its current size.
def expand(self, shape, *args): return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
def permute(self, order, *args): return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args): return mlops.Flip.apply(self, axis=argfix(axis, *args))
# arg is a per-axis (start, end) list; bounds may overhang (see pad2d) — mlops.Slice handles the fill.
def slice(self, arg): return mlops.Slice.apply(self, arg=arg)
def unsqueeze(self, dim):
    """Return a reshaped tensor with an extra size-1 axis at position dim (negative dims count from the end)."""
    axis = dim if dim >= 0 else len(self.shape) + dim + 1
    return self.reshape(self.shape[:axis] + (1,) + self.shape[axis:])
def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
    """Affine transform: self @ weight (+ bias).

    A 1-D weight is applied elementwise via mul instead of dot.
    """
    out = self.dot(weight) if len(weight.shape) != 1 else self.mul(weight)  # type: ignore
    return out if bias is None else out.add(bias)
@@ -433,6 +434,11 @@ class Tensor:
x = (self - mean.reshape(shape=[1, -1, 1, 1])) * weight.reshape(shape=[1, -1, 1, 1])
return x.mul(invstd.reshape(shape=[1, -1, 1, 1])) + bias.reshape(shape=[1, -1, 1, 1])
def dropout(self, p=0.5) -> Tensor:
    """Inverted dropout: zero elements with probability p and scale survivors by 1/(1-p); identity when not training."""
    if not Tensor.training: return self
    # host-side Bernoulli(1-p) keep-mask, cast to this tensor's dtype
    _mask : np.ndarray = np.asarray(Tensor._rng.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
    scale = 1/(1.0 - p)
    return self * Tensor(_mask, requires_grad=False, device=self.device) * scale
# register functions to move between devices
# e.g. a "GPU" buffer backend yields Tensor.gpu(); names starting with "_" are internal and skipped
for device in Device._buffers.keys():
    if device[0] != "_":
        setattr(Tensor, device.lower(), functools.partialmethod(Tensor.to, device))