mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
minor matmul location cleanup
This commit is contained in:
@@ -226,7 +226,8 @@ class Tensor:
|
||||
slice_params[i][dim] = (k, min(self.shape[dim], k+self.shape[dim]//num))
|
||||
return [self.slice(arg=p) for p in slice_params]
|
||||
|
||||
def matmul(self:Tensor, w:Tensor):
|
||||
# TODO: what's the difference between dot and matmul?
|
||||
def dot(self:Tensor, w:Tensor):
|
||||
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
|
||||
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
|
||||
cin, cout = w.shape[-2], w.shape[-1]
|
||||
@@ -244,12 +245,6 @@ class Tensor:
|
||||
cw = w.transpose(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
|
||||
return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).transpose(order=order)
|
||||
|
||||
# TODO: what's the difference between dot and matmul?
|
||||
def dot(self:Tensor, w:Tensor): return self.matmul(w)
|
||||
def __matmul__(self:Tensor, w:Tensor): return self.matmul(w)
|
||||
def __imatmul__(self:Tensor, w:Tensor): self.assign(self.matmul(w))
|
||||
def __rmatmul__(self:Tensor, w:Tensor): return w.matmul(self)
|
||||
|
||||
# (padding_left, padding_right, padding_top, padding_bottom)
|
||||
def pad2d(self, padding:Tuple[int, ...]): return self.slice(arg = [(0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1])])
|
||||
# TODO: this is totally not transpose
|
||||
@@ -358,6 +353,8 @@ class Tensor:
|
||||
def __rpow__(self, x): return Tensor.broadcasted(mlops.Pow, x, self)
|
||||
def __truediv__(self, x): return self * (x.reciprocal() if isinstance(x, Tensor) else (1/x))
|
||||
def __rtruediv__(self, x): return self.reciprocal() * x
|
||||
def __matmul__(self, x): return self.dot(x)
|
||||
def __rmatmul__(self, x:Tensor): return x.dot(self)
|
||||
|
||||
# assignment, any way to make this automatic?
|
||||
def __iadd__(self, x): return self.assign(self.__add__(x))
|
||||
@@ -365,6 +362,7 @@ class Tensor:
|
||||
def __imul__(self, x): return self.assign(self.__mul__(x))
|
||||
def __ipow__(self, x): return self.assign(self.__pow__(x))
|
||||
def __itruediv__(self, x): return self.assign(self.__truediv__(x))
|
||||
def __imatmul__(self, x): self.assign(self.__matmul__(x))
|
||||
|
||||
# simple tensor math API
|
||||
def add(self, x): return self.__add__(x)
|
||||
@@ -372,6 +370,7 @@ class Tensor:
|
||||
def mul(self, x): return self.__mul__(x)
|
||||
def pow(self, x): return self.__pow__(x)
|
||||
def div(self, x): return self.__truediv__(x)
|
||||
def matmul(self, x): return self.__matmul__(x)
|
||||
|
||||
# ***** functional nn ops *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user