diff --git a/test/test_ops.py b/test/test_ops.py
index d170c1e78d..0773c06436 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -74,7 +74,7 @@ class TestOps(unittest.TestCase):
               with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H, width=W):
                 helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
                   lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
-                  lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), gpu=self.gpu, grad_rtol=1e-5, forward_only=self.gpu)
+                  lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), gpu=self.gpu, grad_rtol=1e-5)
 
   def test_strided_conv2d(self):
     bs = 4
diff --git a/tinygrad/opsgpu.py b/tinygrad/opsgpu.py
index ba2a8b8762..d50a0c15c2 100644
--- a/tinygrad/opsgpu.py
+++ b/tinygrad/opsgpu.py
@@ -18,7 +18,7 @@ def buffer_zeros(ctx, shape):
 def buffer_like(ctx, x):
   return buffer_new(ctx, x.shape)
 
-@functools.lru_cache
+@functools.lru_cache()
 def clbuild(cl_ctx, prg):
   return cl.Program(cl_ctx, prg).build()
 
@@ -434,6 +434,8 @@ class Conv2D(Function):
     assert cin*ctx.groups == cin_
     assert cout % ctx.groups == 0
     rcout = cout//ctx.groups
+
+    ctx.save_for_backward(x,w)
 
     # output buffer
     ret = buffer_new(ctx, (bs, cout, oy, ox))
@@ -478,7 +480,66 @@ class Conv2D(Function):
 
   @staticmethod
   def backward(ctx, grad_output):
-    raise Exception("not implemented")
+    bs,_,oy,ox = grad_output.shape
+    x, w = ctx.saved_tensors
+    cout,cin,H,W = w.shape
+    ys,xs = ctx.stride
+    bs,cin_,iy,ix = x.shape
+    oy,ox = (iy-(H-ys))//ys, (ix-(W-xs))//xs
+    assert cin*ctx.groups == cin_
+    assert cout % ctx.groups == 0
+    rcout = cout//ctx.groups
+    dx = buffer_zeros(ctx, (bs, cin_, iy, ix))
+    dw = buffer_new(ctx, (cout, cin, H, W))
+
+    prg = clbuild(ctx.cl_ctx, """
+    __kernel void convw(__global const float *tensx, __global const float *ggg, __global float *dw,
+      int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
+
+      int g = get_global_id(0)/(rcout*cin);   // range 0-groups
+      int c = (get_global_id(0)/cin) % rcout; // range 0-rcout
+      int ci = get_global_id(0) % cin;        // range 0-cin
+      int y = get_global_id(1);               // range 0-H
+      int x = get_global_id(2);               // range 0-W
+
+      // tensx = (bs, groups*cin, iy, ix)
+      // tensw = (groups*rcout, cin, H, W)
+      // ggg = (bs, groups*rcout, oy, ox)
+      float acc = 0.0;
+      for (int Y = 0; Y < oy; Y++) {
+        for (int X = 0; X < ox; X++) {
+          for (int B = 0; B < bs; B++) {
+            acc += ggg[B*groups*rcout*oy*ox +
+              g*rcout*oy*ox + c*oy*ox + Y*ox + X]*tensx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x];
+          }
+        }
+      }
+      dw[get_global_id(0)*H*W + y*W + x] = acc;
+    }
+    __kernel void convx(__global const float *tensw, __global const float *ggg, __global float *dx,
+      int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
+
+      int B = get_global_id(0);
+      int g = get_global_id(1);
+      int ci = get_global_id(2);
+
+      for (int c = 0; c < rcout; c++) {
+        for (int Y = 0; Y < oy; Y++) {
+          for (int X = 0; X < ox; X++) {
+            for (int y = 0; y < H; y++) {
+              for (int x = 0; x < W; x++) {
+                dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x] += ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X]*tensw[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
+              }
+            }
+          }
+        }
+      }
+    }
+    """)
+
+    prg.convw(ctx.cl_queue, [ctx.groups*rcout*cin, H, W], None, x, grad_output, dw, np.int32(H), np.int32(W), np.int32(ctx.groups),
+      np.int32(rcout), np.int32(cin), np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix), np.int32(ys), np.int32(xs), np.int32(bs))
+
+    prg.convx(ctx.cl_queue, [bs, ctx.groups, cin], None, w, grad_output, dx, np.int32(H), np.int32(W), np.int32(ctx.groups),
+      np.int32(rcout), np.int32(cin), np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix), np.int32(ys), np.int32(xs), np.int32(bs))
+    return dx, dw
 
 register('conv2d', Conv2D, gpu=True)
-
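For reference, the two kernels implement the standard convolution gradients: `convw` correlates the saved input with the incoming gradient to produce the weight gradient, and `convx` scatters the incoming gradient back through the weights to produce the input gradient. Below is a minimal NumPy sketch of the same computation, useful for cross-checking the kernel indexing; `conv2d_backward_ref` is a hypothetical helper, not part of the patch, and assumes the same `(bs, groups*cin, iy, ix)` input and `(groups*rcout, cin, H, W)` weight layouts the kernels use.

```python
import numpy as np

def conv2d_backward_ref(x, w, ggg, groups=1, ys=1, xs=1):
  # Hypothetical reference, not part of the patch: plain-NumPy version of
  # what convw/convx compute, with the same tensor layouts.
  bs, cin_, iy, ix = x.shape
  cout, cin, H, W = w.shape
  rcout = cout // groups
  _, _, oy, ox = ggg.shape
  dx, dw = np.zeros_like(x), np.zeros_like(w)
  for g in range(groups):
    for c in range(rcout):
      for ci in range(cin):
        for Y in range(oy):
          for X in range(ox):
            # the (H, W) input patch that produced output pixel (Y, X)
            patch = x[:, g*cin+ci, Y*ys:Y*ys+H, X*xs:X*xs+W]
            go = ggg[:, g*rcout+c, Y, X]  # shape (bs,)
            # convw: dL/dw accumulates grad * input over batch and output positions
            dw[g*rcout+c, ci] += np.einsum('b,bhw->hw', go, patch)
            # convx: dL/dx scatters grad back through the weights
            dx[:, g*cin+ci, Y*ys:Y*ys+H, X*xs:X*xs+W] += go[:, None, None] * w[g*rcout+c, ci]
  return dx, dw
```

The GPU version does the same accumulation with the loop nest split across an NDRange: `convw` parallelizes over `(groups*rcout*cin, H, W)` with one `dw` element per work-item, while `convx` parallelizes over `(bs, groups, cin)` and keeps the remaining loops inside the kernel, so no two work-items ever write the same `dx` element.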
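One note on the first opsgpu.py hunk: on Python 3.7 and earlier, `functools.lru_cache` must be called to produce a decorator, so the bare `@functools.lru_cache` form only works on Python 3.8+. Adding the parentheses keeps `clbuild` compatible with older interpreters while still caching one compiled `cl.Program` per `(context, source)` pair. A quick illustration (`build` here is a stand-in, not from the patch):

```python
import functools

@functools.lru_cache()  # parentheses: works on Python 3.2+; the bare form needs 3.8+
def build(src):
  print("compiling", src)  # runs once per distinct src
  return src.upper()

build("kernel")  # compiles
build("kernel")  # cache hit, no recompile
```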