diff --git a/test/test_ops.py b/test/test_ops.py
index 72e1c052b8..562e474254 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -14,7 +14,6 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-7, grad_atol=1e-7, gpu=False, forward_only=False):
   out = torch_fxn(*ts)
   ret = tinygrad_fxn(*tst)
 
-  # TODO: why so inaccurate?
   np.testing.assert_allclose(ret.cpu().data, out.detach().numpy(), atol=atol)
 
   if not forward_only:
@@ -66,11 +65,11 @@ class TestOps(unittest.TestCase):
     for bs in [1,8]:
       for cin in [1,3]:
         for groups in [1,3] if cin == 3 else [1]:
-          for H in [2,5]:
-            for W in [2,3,5]:
+          for H in [1,2,5]:
+            for W in [1,2,3,5]:
               helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
                 lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
-                lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=2e-5, grad_atol=2e-6, gpu=self.gpu)
+                lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=2e-5, grad_atol=2e-6, gpu=self.gpu, forward_only=self.gpu)
 
   def test_strided_conv2d(self):
     bs = 4
diff --git a/tinygrad/opsgpu.py b/tinygrad/opsgpu.py
index a76d0b5fc6..0f61f90fc4 100644
--- a/tinygrad/opsgpu.py
+++ b/tinygrad/opsgpu.py
@@ -304,4 +304,64 @@ class LogSoftmax(Function):
     return grad_input
 register('logsoftmax', LogSoftmax, gpu=True)
 
+# ************* conv ops *************
+
+class Conv2D(Function):
+  @staticmethod
+  def forward(ctx, x, w, stride=1, groups=1):
+    if type(ctx.stride) == int:
+      ctx.stride = (ctx.stride, ctx.stride)
+    cout,cin,H,W = w.shape
+    ys,xs = ctx.stride
+    bs,cin_,iy,ix = x.shape
+    oy,ox = (iy-(H-ys))//ys, (ix-(W-xs))//xs
+    assert cin*ctx.groups == cin_
+    assert cout % ctx.groups == 0
+    rcout = cout//ctx.groups
+
+    # output buffer
+    ret = buffer_new(ctx, (bs, cout, oy, ox))
+
+    prg = clbuild(ctx.cl_ctx, """
+    __kernel void conv(__global const float *input, __global const float *weight, __global float *output,
+      int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix) {
+
+      int B = get_global_id(0);  // range 0-bs
+      int Y = get_global_id(1);  // range 0-oy
+      int X = get_global_id(2);  // range 0-ox
+
+      // input  = (bs, groups, cin, iy, ix)
+      // weight = (groups, rcout, cin, H, W)
+      // output = (bs, groups, rcout, oy, ox)
+      for (int g = 0; g < groups; g++) {
+        for (int c = 0; c < rcout; c++) {
+          float acc = 0.0;
+          for (int ci = 0; ci < cin; ci++) {
+            for (int y = Y; y < Y+H; y++) {
+              for (int x = X; x < X+W; x++) {
+                acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x] * \
+                  weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-Y)*W + (x-X)];
+              }
+            }
+          }
+          output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
+        }
+      }
+    }
+    """)
+
+    prg.conv(ctx.cl_queue, [bs, oy, ox], None,
+      x, w, ret,
+      np.int32(H), np.int32(W),
+      np.int32(groups), np.int32(rcout), np.int32(cin),
+      np.int32(oy), np.int32(ox),
+      np.int32(iy), np.int32(ix)
+    )
+    return ret
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    raise Exception("not implemented")
+
+register('conv2d', Conv2D, gpu=True)
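
Two notes for readers decoding the patch. First, `forward()` declares `stride=1, groups=1` but reads `ctx.stride` and `ctx.groups`: in this era of tinygrad, `Function.apply` copies keyword arguments (and their declared defaults) onto the context object, the same pattern the CPU `Conv2D` uses. Second, the kernel computes `oy,ox` from the stride but reads the input patch starting at `(Y, X)` rather than `(Y*ys, X*xs)`, so it is only correct for stride 1; consistent with that, the test changes above exercise GPU conv only through the unstrided `test_conv2d`. Below is a rough NumPy sketch of the grouped convolution the kernel is aiming at; `conv2d_ref` is our own illustrative helper, not code from this repo.

```python
import numpy as np

def conv2d_ref(x, w, stride=1, groups=1):
  # Hand-rolled reference (an assumption, not repo code) mirroring the
  # kernel's layout comments:
  #   input  = (bs, groups, cin, iy, ix)
  #   weight = (groups, rcout, cin, H, W)
  #   output = (bs, groups, rcout, oy, ox)
  ys, xs = (stride, stride) if isinstance(stride, int) else stride
  cout, cin, H, W = w.shape
  bs, cin_, iy, ix = x.shape
  assert cin * groups == cin_ and cout % groups == 0
  rcout = cout // groups
  oy, ox = (iy - (H - ys)) // ys, (ix - (W - xs)) // xs
  gx = x.reshape(bs, groups, cin, iy, ix)
  gw = w.reshape(groups, rcout, cin, H, W)
  out = np.zeros((bs, groups, rcout, oy, ox), dtype=x.dtype)
  for B in range(bs):
    for g in range(groups):
      for c in range(rcout):
        for Y in range(oy):
          for X in range(ox):
            # the OpenCL kernel reads this patch at (Y, X); for ys == xs == 1
            # that coincides with the (Y*ys, X*xs) used here
            out[B, g, c, Y, X] = (gx[B, g, :, Y*ys:Y*ys+H, X*xs:X*xs+W] * gw[g, c]).sum()
  return out.reshape(bs, cout, oy, ox)
```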
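
The output-size expression `oy,ox = (iy-(H-ys))//ys, (ix-(W-xs))//xs` is the standard size of an unpadded convolution rearranged, since `(iy-(H-ys))//ys == (iy-H)//ys + 1` for any positive integer stride. A quick sanity check:

```python
# (iy - (H - ys)) // ys  ==  (iy - H) // ys + 1  ==  floor((iy - H) / ys) + 1,
# the usual output size of a VALID (unpadded) convolution
for iy, H, ys in [(28, 5, 1), (28, 5, 2), (11, 2, 3)]:
  assert (iy - (H - ys)) // ys == (iy - H) // ys + 1
```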
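
Since `backward()` raises (hence `forward_only=self.gpu` in the updated test), a forward-only smoke test is about all this patch supports. A sketch of what that might look like, assuming the `Tensor.gpu()`/`.cpu()` transfer helpers implied by the `ret.cpu().data` call in `helper_test_op`:

```python
import numpy as np
from tinygrad.tensor import Tensor

# Hypothetical smoke test, not part of the patch. Shapes chosen so that
# groups=3 gives cin//groups == 1 and rcout == 2.
x = Tensor(np.random.uniform(size=(4, 3, 11, 28)).astype(np.float32)).gpu()
w = Tensor(np.random.uniform(size=(6, 1, 2, 2)).astype(np.float32)).gpu()
out = x.conv2d(w, groups=3)   # bound as a Tensor method by register('conv2d', ...)
print(out.cpu().data.shape)   # expected: (4, 6, 10, 27)
```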