diff --git a/README.md b/README.md
index d9461a59fd..e52f9e7005 100644
--- a/README.md
+++ b/README.md
@@ -111,7 +111,7 @@ You need to support 14 basic ops:
 Relu, Log, Exp                          # unary ops
 Add, Sub, Mul, Pow                      # binary ops (with broadcasting)
 Sum, Max                                # reduce ops (with axis argument)
-Dot, Conv2D                             # matrix multiplication and conv
+Matmul, Conv2D                          # matrix multiplication and conv
 Reshape, Transpose, Slice               # moving things around ops
 ```
 
diff --git a/tinygrad/ops_cpu.py b/tinygrad/ops_cpu.py
index f645f003ba..70bb8a218f 100644
--- a/tinygrad/ops_cpu.py
+++ b/tinygrad/ops_cpu.py
@@ -89,7 +89,7 @@ register('max', Max)
 
 # ************* GEMM *************
 
-class Dot(Function):
+class Matmul(Function):
   @staticmethod
   def forward(ctx, input, weight):
     ctx.save_for_backward(input, weight)
@@ -101,7 +101,7 @@ class Dot(Function):
     grad_input = grad_output @ np.swapaxes(weight, -2, -1)
     grad_weight = np.swapaxes(input, -2, -1) @ grad_output
     return grad_input, grad_weight
-register('dot', Dot)
+register('matmul', Matmul)
 
 # ************* movement ops *************
 
diff --git a/tinygrad/ops_gpu.py b/tinygrad/ops_gpu.py
index 88aa0593d2..51746cd173 100644
--- a/tinygrad/ops_gpu.py
+++ b/tinygrad/ops_gpu.py
@@ -246,7 +246,7 @@ class Max(Function):
     return binary_op(ctx, 'a*b', ret2, GPUBuffer(shape, grad_output))
 register('max', Max, device=Device.GPU)
 
-class Dot(Function):
+class Matmul(Function):
   @staticmethod
   def forward(ctx, input, weight):
     assert input.shape[-1] == weight.shape[-2]
@@ -299,7 +299,7 @@ class Dot(Function):
       i32(1), msize, isize, i32(1), osize, osize)
     return grad_input, grad_weight
-register('dot', Dot, device=Device.GPU)
+register('matmul', Matmul, device=Device.GPU)
 
 # ************* movement ops *************
 
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index c561bcc9f3..6dd73a3dd9 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -207,8 +207,8 @@ class Tensor:
   def pad2d(self, padding):
     return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
 
-  def matmul(self, w):
-    return self.dot(w)
+  def dot(self, w):
+    return self.matmul(w)
 
   def mean(self, axis=None):
     out = self.sum(axis=axis)
@@ -304,7 +304,7 @@ def register(name, fxn, device=Device.CPU):
     return f.apply(f, *x, **kwargs)
   setattr(Tensor, name, dispatch)
   # TODO: div is a second class op, so it doesn't work here
-  if name in ['add', 'sub', 'mul', 'pow']:
+  if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
     setattr(Tensor, f"__{name}__", dispatch)
     setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))
     setattr(Tensor, f"__r{name}__", lambda self,x: dispatch(x,self))
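
Net effect of the rename: `Matmul` becomes the first-class op on both the CPU and GPU backends, `Tensor.dot` survives as a thin alias that forwards to `matmul`, and adding `'matmul'` to the dunder list in `register` installs `__matmul__`, `__imatmul__`, and `__rmatmul__`, so the `@` operator dispatches through the same path as the named method. A minimal usage sketch of what this enables (shapes and variable names are illustrative; it assumes the `Tensor` constructor accepts a numpy array, as in `tensor.py` at this commit):

```python
import numpy as np
from tinygrad.tensor import Tensor

x = Tensor(np.random.uniform(size=(3, 5)).astype(np.float32))
w = Tensor(np.random.uniform(size=(5, 2)).astype(np.float32))

y1 = x.matmul(w)  # named method, installed by register('matmul', Matmul)
y2 = x.dot(w)     # old name, kept as an alias that forwards to matmul
y3 = x @ w        # __matmul__, wired up by the dunder loop in register()

assert y1.shape == y2.shape == y3.shape == (3, 2)
```

The in-place form `x @= w` comes along for free via the same loop: `__imatmul__` is implemented as `self.assign(dispatch(self, x))`, exactly like `__iadd__` and friends.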