Dot -> Matmul

George Hotz
2020-12-30 10:41:51 -05:00
parent 10fc3ff5b9
commit 2d44bf7f1a
4 changed files with 8 additions and 8 deletions


@@ -111,7 +111,7 @@ You need to support 14 basic ops:
 Relu, Log, Exp                  # unary ops
 Add, Sub, Mul, Pow              # binary ops (with broadcasting)
 Sum, Max                        # reduce ops (with axis argument)
-Dot, Conv2D                     # matrix multiplication and conv
+Matmul, Conv2D                  # matrix multiplication and conv
 Reshape, Transpose, Slice       # moving things around ops
 ```


@@ -89,7 +89,7 @@ register('max', Max)
 # ************* GEMM *************

-class Dot(Function):
+class Matmul(Function):
   @staticmethod
   def forward(ctx, input, weight):
     ctx.save_for_backward(input, weight)
@@ -101,7 +101,7 @@ class Dot(Function):
     grad_input = grad_output @ np.swapaxes(weight, -2, -1)
     grad_weight = np.swapaxes(input, -2, -1) @ grad_output
     return grad_input, grad_weight
-register('dot', Dot)
+register('matmul', Matmul)

 # ************* movement ops *************
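For context, the renamed CPU op is a thin wrapper over NumPy's `@`. Here is a minimal sketch of the full class after this commit, assuming the `Function`/`register` machinery defined next to `Tensor` (see the last file below); the `return input @ weight` line and the `backward` signature are inferred, since the hunks cut them off:

```python
import numpy as np
from tinygrad.tensor import Function, register  # assumed import path

class Matmul(Function):
  @staticmethod
  def forward(ctx, input, weight):
    ctx.save_for_backward(input, weight)
    return input @ weight  # assumed: the hunk ends before the return

  @staticmethod
  def backward(ctx, grad_output):
    input, weight = ctx.saved_tensors
    # dL/dx = dL/dy @ W^T and dL/dW = x^T @ dL/dy, exactly as in the hunk
    grad_input = grad_output @ np.swapaxes(weight, -2, -1)
    grad_weight = np.swapaxes(input, -2, -1) @ grad_output
    return grad_input, grad_weight
register('matmul', Matmul)
```

Note `np.swapaxes(..., -2, -1)` rather than `.T`: it transposes only the last two axes, which keeps the batched-matmul gradients correct.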


@@ -246,7 +246,7 @@ class Max(Function):
     return binary_op(ctx, 'a*b', ret2, GPUBuffer(shape, grad_output))
 register('max', Max, device=Device.GPU)

-class Dot(Function):
+class Matmul(Function):
   @staticmethod
   def forward(ctx, input, weight):
     assert input.shape[-1] == weight.shape[-2]
@@ -299,7 +299,7 @@ class Dot(Function):
       i32(1), msize, isize, i32(1), osize, osize)
     return grad_input, grad_weight
-register('dot', Dot, device=Device.GPU)
+register('matmul', Matmul, device=Device.GPU)

 # ************* movement ops *************
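The GPU forward guards the kernel launch with that shape assert. The contract it checks is the ordinary matmul shape rule; a standalone NumPy illustration (not tinygrad code):

```python
import numpy as np

# (..., m, k) @ (..., k, n) -> (..., m, n): inner dimensions must agree
input = np.zeros((2, 3, 4), dtype=np.float32)   # k = 4
weight = np.zeros((2, 4, 5), dtype=np.float32)  # k = 4, matches
assert input.shape[-1] == weight.shape[-2]
out = input @ weight
assert out.shape == (2, 3, 5)
```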


@@ -207,8 +207,8 @@ class Tensor:
   def pad2d(self, padding):
     return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]

-  def matmul(self, w):
-    return self.dot(w)
+  def dot(self, w):
+    return self.matmul(w)

   def mean(self, axis=None):
     out = self.sum(axis=axis)
@@ -304,7 +304,7 @@ def register(name, fxn, device=Device.CPU):
     return f.apply(f, *x, **kwargs)
   setattr(Tensor, name, dispatch)
   # TODO: div is a second class op, so it doesn't work here
-  if name in ['add', 'sub', 'mul', 'pow']:
+  if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
     setattr(Tensor, f"__{name}__", dispatch)
     setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))
     setattr(Tensor, f"__r{name}__", lambda self,x: dispatch(x,self))