Dot -> Matmul
@@ -111,7 +111,7 @@ You need to support 14 basic ops:
 Relu, Log, Exp              # unary ops
 Add, Sub, Mul, Pow          # binary ops (with broadcasting)
 Sum, Max                    # reduce ops (with axis argument)
-Dot, Conv2D                 # matrix multiplication and conv
+Matmul, Conv2D              # matrix multiplication and conv
 Reshape, Transpose, Slice   # moving things around ops
 ```
 
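For readers skimming the diff, here is a quick illustration of the op categories listed above (plain numpy, not tinygrad code); tinygrad's CPU ops follow the same broadcasting and axis semantics.

```python
import numpy as np

a = np.arange(6.0).reshape(2, 3)
print(np.maximum(a, 0.0))             # unary op (Relu)
print(a + np.array([1.0, 2.0, 3.0]))  # binary op with broadcasting: (2,3) + (3,)
print(a.sum(axis=1))                  # reduce op with an axis argument
print(a @ np.ones((3, 2)))            # Matmul -- the op this commit renames
```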
@@ -89,7 +89,7 @@ register('max', Max)
 
 # ************* GEMM *************
 
-class Dot(Function):
+class Matmul(Function):
   @staticmethod
   def forward(ctx, input, weight):
     ctx.save_for_backward(input, weight)
@@ -101,7 +101,7 @@ class Dot(Function):
     grad_input = grad_output @ np.swapaxes(weight, -2, -1)
     grad_weight = np.swapaxes(input, -2, -1) @ grad_output
     return grad_input, grad_weight
-register('dot', Dot)
+register('matmul', Matmul)
 
 # ************* movement ops *************
 
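As a sanity check on the backward pass in the CPU hunk above, here is a minimal standalone sketch (plain numpy, without tinygrad's Function/ctx machinery) of the same two gradient identities, grad_input = grad_output @ Wᵀ and grad_weight = Xᵀ @ grad_output, verified against a finite difference:

```python
import numpy as np

def matmul_backward(grad_output, x, w):
  # Same identities as the backward above; swapaxes(-2, -1) transposes the
  # last two axes, so the rule also covers batched inputs.
  grad_x = grad_output @ np.swapaxes(w, -2, -1)
  grad_w = np.swapaxes(x, -2, -1) @ grad_output
  return grad_x, grad_w

rng = np.random.default_rng(0)
x, w = rng.standard_normal((3, 4)), rng.standard_normal((4, 5))
gx, gw = matmul_backward(np.ones((3, 5)), x, w)  # L = sum(x @ w)

# Finite-difference check of dL/dx[0,0].
eps = 1e-6
dx = np.zeros_like(x); dx[0, 0] = eps
num = (np.sum((x + dx) @ w) - np.sum(x @ w)) / eps
assert abs(num - gx[0, 0]) < 1e-4
```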
@@ -246,7 +246,7 @@ class Max(Function):
     return binary_op(ctx, 'a*b', ret2, GPUBuffer(shape, grad_output))
 register('max', Max, device=Device.GPU)
 
-class Dot(Function):
+class Matmul(Function):
   @staticmethod
   def forward(ctx, input, weight):
     assert input.shape[-1] == weight.shape[-2]
@@ -299,7 +299,7 @@ class Dot(Function):
       i32(1), msize, isize, i32(1), osize, osize)
 
     return grad_input, grad_weight
-register('dot', Dot, device=Device.GPU)
+register('matmul', Matmul, device=Device.GPU)
 
 # ************* movement ops *************
 
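The GPU forward asserts the inner-dimension contract before launching the GEMM kernel. A quick numpy illustration (hypothetical shapes, not tinygrad code) of what that assert guards:

```python
import numpy as np

# input.shape[-1] must equal weight.shape[-2]; leading batch axes ride along.
x = np.zeros((32, 64, 128))   # (batch, M, K)
w = np.zeros((128, 10))       # (K, N)
assert x.shape[-1] == w.shape[-2]
print((x @ w).shape)          # (32, 64, 10)
```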
@@ -207,8 +207,8 @@ class Tensor:
   def pad2d(self, padding):
     return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
 
-  def matmul(self, w):
-    return self.dot(w)
+  def dot(self, w):
+    return self.matmul(w)
 
   def mean(self, axis=None):
     out = self.sum(axis=axis)
@@ -304,7 +304,7 @@ def register(name, fxn, device=Device.CPU):
     return f.apply(f, *x, **kwargs)
   setattr(Tensor, name, dispatch)
   # TODO: div is a second class op, so it doesn't work here
-  if name in ['add', 'sub', 'mul', 'pow']:
+  if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
     setattr(Tensor, f"__{name}__", dispatch)
     setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))
     setattr(Tensor, f"__r{name}__", lambda self,x: dispatch(x,self))
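Adding 'matmul' to that list is what makes Python's @ operator work on tensors: register() now also installs __matmul__, __imatmul__, and __rmatmul__, and since Tensor.dot is now a thin alias for matmul, `x @ w`, `x.matmul(w)`, and `x.dot(w)` all reach the same op. A toy sketch of the dunder wiring (a stand-in class, not tinygrad):

```python
import numpy as np

class T:
  """Stand-in for Tensor, just to show the operator dispatch."""
  def __init__(self, v): self.v = v
  def matmul(self, other): return T(self.v @ other.v)

# What register('matmul', ...) now does in addition to setattr(T, 'matmul', ...):
setattr(T, '__matmul__', T.matmul)

a, b = T(np.eye(2)), T(np.full((2, 2), 3.0))
print((a @ b).v)  # the @ operator dispatches through __matmul__
```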