support multidot on GPU

Commit in the mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 23:48:01 -05:00).
@@ -75,7 +75,6 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, device=self.device)
   def test_dot(self):
     helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, device=self.device)
-  @cpu_only
   def test_multidot(self):
     helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
   def test_sum(self):
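Note: with @cpu_only dropped, test_multidot now runs on the GPU device as well. It exercises batched matrix multiplication: a (10,45,65) tensor times a (10,65,45) tensor gives (10,45,45), with the leading axis treated as an independent batch. A minimal NumPy sketch of the semantics being checked (names here are illustrative, not from the repo):

import numpy as np

# Batched matmul semantics exercised by test_multidot: the leading axis is a
# batch dimension; each (45,65) slice multiplies its (65,45) partner.
a = np.random.randn(10, 45, 65).astype(np.float32)
b = np.random.randn(10, 65, 45).astype(np.float32)

out = a @ b  # shape (10, 45, 45)

# Equivalent per-slice loop, which is what the GPU kernel below parallelizes
# over get_global_id(2).
ref = np.stack([a[i] @ b[i] for i in range(a.shape[0])])
assert out.shape == (10, 45, 45)
assert np.allclose(out, ref, atol=1e-4)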
@@ -163,7 +162,6 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(),
       lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), device=self.device)

-  @cpu_only
   def test_maxpool2d(self):
     for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
       with self.subTest(kernel_size=ksz):
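Note: helper_test_op drives all of these cases by running the same random inputs through the torch function and the tinygrad function and comparing the results. A hypothetical sketch of that pattern (helper_test_op_sketch and its body are illustrative, not the repo's implementation, which also does things like gradient checks):

import numpy as np
import torch
from tinygrad.tensor import Tensor

# Hypothetical sketch of the helper_test_op pattern: feed identical random
# inputs to the torch and tinygrad implementations and compare the forward
# outputs. (The device argument is accepted but unused in this CPU-only sketch.)
def helper_test_op_sketch(shapes, torch_fn, tinygrad_fn, device=None, atol=1e-4):
  data = [np.random.randn(*s).astype(np.float32) for s in shapes]
  torch_out = torch_fn(*[torch.tensor(d) for d in data])
  tiny_out = tinygrad_fn(*[Tensor(d) for d in data])
  np.testing.assert_allclose(torch_out.detach().numpy(), tiny_out.data, atol=atol)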
@@ -249,49 +249,53 @@ register('max', Max, device=Device.GPU)
 class Dot(Function):
   @staticmethod
   def forward(ctx, input, weight):
-    assert input.shape[1] == weight.shape[0]
-    isize, msize, osize = i32(input.shape[0]), i32(input.shape[1]), i32(weight.shape[1])
-    ret = buffer_new(ctx, (isize, osize))
+    assert input.shape[-1] == weight.shape[-2]
+    cnt = input.shape[0] if len(input.shape) == 3 else 1
+    isize, msize, osize = i32(input.shape[-2]), i32(input.shape[-1]), i32(weight.shape[-1])
+    ret = buffer_new(ctx, (isize, osize) if cnt == 1 else (cnt, isize, osize))

     matmul = clbuild(ctx.cl_ctx, "matmul", """
     __kernel void matmul(
       __global const float *input, __global const float *weight, __global float *res,
-      int is0, int is1, int msize, int ws0, int ws1, int osize
+      int isize, int is0, int is1, int msize, int ws0, int ws1, int osize
     ) {
+      int stride = get_global_id(2);
+
       int X = get_global_id(0); // isize
       int Y = get_global_id(1); // osize

       float ret = 0.0;
       for (int x = 0; x < msize; x++) {
-        ret += input[X * is0 + x * is1] * weight[Y * ws0 + x * ws1];
+        ret += input[X * is0 + x * is1 + isize*msize*stride] *
+          weight[Y * ws0 + x * ws1 + msize*osize*stride];
       }

-      res[X * osize + Y] = ret;
+      res[X * osize + Y + isize*osize*stride] = ret;
     }""")
-    ctx.save_for_backward(input, weight, matmul)
+    ctx.save_for_backward(input, weight, matmul, cnt)

     # (isize,msize) x (msize,osize) = (isize,osize)
-    matmul(ctx.cl_queue, [isize, osize], None,
-           input.cl, weight.cl, ret.cl,
+    matmul(ctx.cl_queue, [isize, osize, cnt], None,
+           input.cl, weight.cl, ret.cl, isize,
            msize, i32(1), msize, i32(1), osize, osize)
     return ret

   @staticmethod
   def backward(ctx, grad_output):
-    input, weight, matmul = ctx.saved_tensors
-    isize, msize, osize = i32(input.shape[0]), i32(input.shape[1]), i32(weight.shape[1])
+    input, weight, matmul, cnt = ctx.saved_tensors
+    isize, msize, osize = i32(input.shape[-2]), i32(input.shape[-1]), i32(weight.shape[-1])

     grad_input = buffer_new(ctx, input.shape)
     grad_weight = buffer_new(ctx, weight.shape)

     # (isize,osize) x (msize,osize) = (isize,msize)
-    matmul(ctx.cl_queue, [isize, msize], None,
-           grad_output.cl, weight.cl, grad_input.cl,
+    matmul(ctx.cl_queue, [isize, msize, cnt], None,
+           grad_output.cl, weight.cl, grad_input.cl, isize,
            osize, i32(1), osize, osize, i32(1), msize)

     # (isize,msize) x (isize,osize) = (msize,osize)
-    matmul(ctx.cl_queue, [msize, osize], None,
-           input.cl, grad_output.cl, grad_weight.cl,
+    matmul(ctx.cl_queue, [msize, osize, cnt], None,
+           input.cl, grad_output.cl, grad_weight.cl, msize,
            i32(1), msize, isize, i32(1), osize, osize)

     return grad_input, grad_weight
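Note on the kernel change: each work item (X, Y, stride) = (get_global_id(0), get_global_id(1), get_global_id(2)) now computes one output element of one batch slice, and the offsets isize*msize, msize*osize and isize*osize skip whole input, weight and output slices per batch index. Because the element strides is0/is1 and ws0/ws1 stay kernel parameters, the same kernel also computes both backward products, grad_input = grad_output x weight^T and grad_weight = input^T x grad_output, just by passing transposed strides, so backward needs no new kernel. A NumPy emulation of the flat indexing (illustrative only; it assumes row-major buffers and that both operands carry the batch dimension, as in test_multidot):

import numpy as np

# Emulation of the OpenCL "matmul" kernel above on flat row-major buffers.
# (X, Y, stride) mirror get_global_id(0..2); is0/is1 and ws0/ws1 are element
# strides, so passing (1, n) instead of (n, 1) reads a transposed view.
def matmul_kernel(inp, w, isize, is0, is1, msize, ws0, ws1, osize, cnt):
  res = np.zeros(cnt * isize * osize, dtype=np.float32)
  for stride in range(cnt):        # batch index (global id 2)
    for X in range(isize):         # output row  (global id 0)
      for Y in range(osize):       # output col  (global id 1)
        ret = 0.0
        for x in range(msize):
          ret += inp[X*is0 + x*is1 + isize*msize*stride] * \
                 w[Y*ws0 + x*ws1 + msize*osize*stride]
        res[X*osize + Y + isize*osize*stride] = ret
  return res

cnt, isize, msize, osize = 4, 3, 5, 2
a = np.random.randn(cnt, isize, msize).astype(np.float32)
b = np.random.randn(cnt, msize, osize).astype(np.float32)

# Forward dispatch: matmul(..., [isize, osize, cnt], ..., isize,
#                          msize, 1, msize, 1, osize, osize)
out = matmul_kernel(a.ravel(), b.ravel(), isize, msize, 1, msize, 1, osize, osize, cnt)
assert np.allclose(out.reshape(cnt, isize, osize), a @ b, atol=1e-4)

# grad_input dispatch reuses the same kernel with transposed weight strides:
# matmul(..., [isize, msize, cnt], ..., isize, osize, 1, osize, osize, 1, msize)
g = np.random.randn(cnt, isize, osize).astype(np.float32)
gi = matmul_kernel(g.ravel(), b.ravel(), isize, osize, 1, osize, osize, 1, msize, cnt)
assert np.allclose(gi.reshape(cnt, isize, msize), g @ b.transpose(0, 2, 1), atol=1e-4)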