From fef6c82491a1d873564b5b2bc67f94eda47112d9 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Wed, 15 Jun 2022 11:38:23 -0700
Subject: [PATCH] wow dilation support was simple

---
 test/test_ops.py          | 23 +++++++++++++++--------
 tinygrad/llops/ops_cpu.py |  2 +-
 tinygrad/llops/ops_gpu.py | 10 +++++-----
 3 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 079c28c2ac..7cdd7f5ebb 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -268,19 +268,26 @@ class TestOps(unittest.TestCase):
         lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
         lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
 
+  def test_dilated_conv2d_forward(self):
+    bs = 4
+    cin = 3
+    H,W = 3,3
+    for d in [2, (2,1)]:
+      with self.subTest(dilation := d):
+        helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
+          lambda x,w: torch.nn.functional.conv2d(x,w,dilation=dilation).relu(),
+          lambda x,w: Tensor.conv2d(x,w,dilation=dilation).relu(), atol=1e-4, forward_only=True)
+
   @unittest.skipUnless(Device.DEFAULT == Device.TORCH, "Not Implemented")
   def test_dilated_conv2d(self):
     bs = 4
     cin = 3
     H,W = 3,3
-    with self.subTest(dilation := 2):
-      helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
-        lambda x,w: torch.nn.functional.conv2d(x,w,dilation=dilation).relu(),
-        lambda x,w: Tensor.conv2d(x,w,dilation=dilation).relu(), atol=1e-4)
-    with self.subTest(dilation := (2,1)):
-      helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
-        lambda x,w: torch.nn.functional.conv2d(x,w,dilation=dilation).relu(),
-        lambda x,w: Tensor.conv2d(x,w,dilation=dilation).relu(), atol=1e-4)
+    for d in [2, (2,1)]:
+      with self.subTest(dilation := d):
+        helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
+          lambda x,w: torch.nn.functional.conv2d(x,w,dilation=dilation).relu(),
+          lambda x,w: Tensor.conv2d(x,w,dilation=dilation).relu(), atol=1e-4)
 
   def test_maxpool2d(self):
     for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py
index 9522246196..965d3a7621 100644
--- a/tinygrad/llops/ops_cpu.py
+++ b/tinygrad/llops/ops_cpu.py
@@ -64,7 +64,7 @@ def get_tx(x, C):
   gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
   return np.lib.stride_tricks.as_strided(gx,
     shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W),
-    strides=(*gx.strides[0:3], gx.strides[3]*C.ys, gx.strides[4]*C.xs, *gx.strides[3:5]),
+    strides=(*gx.strides[0:3], gx.strides[3]*C.ys, gx.strides[4]*C.xs, gx.strides[3]*C.dy, gx.strides[4]*C.dx),
     writeable=False,
   )
 
diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index 84cacf6fef..7593abd48f 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -132,7 +132,7 @@ def conv(x,w,ret,C):
   # output = (bs, groups, rcout, oy, ox)
   conv_prg = clbuild("conv", """
   __kernel void conv(__global const float *input, __global const float *weight, __global float *output,
-    int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
+    int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs, int dx, int dy) {
 
     int B = get_global_id(0)/(groups*rcout);  // range 0-bs
     int g = (get_global_id(0)/rcout)%groups;
@@ -145,15 +145,15 @@ def conv(x,w,ret,C):
 
     float acc = 0.0;
     for (int ci = 0; ci < cin; ci++) {
-      for (int y = IY; y < IY+H; y++) { for (int x = IX; x < IX+W; x++) {
-        acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x] * \
-          weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-IY)*W + (x-IX)];
+      for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
+        acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (y*dy+IY)*ix + (x*dx+IX)] * \
+          weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
       } }
     }
     output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
   }""")
 
-  conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in C[0:12]])
+  conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in list(C[0:12])+[C.dx, C.dy]])
   # tensx = (bs, groups*cin, iy, ix)
   # tensw = (groups*rcout, cin, H, W)
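
Note on the ops_cpu.py hunk: dilation falls out of the existing as_strided im2col trick
because the last two (window) strides get scaled by C.dy/C.dx instead of reusing the raw
row/column strides. Below is a minimal standalone sketch of the same idea, assuming a
single 2D image, one filter, conv stride 1, and no groups or padding; the helper name
dilated_windows is illustrative, not from the patch.

import numpy as np

def dilated_windows(img, H, W, dy, dx):
  # (oy, ox, H, W) view: window (Y, X), element (y, x) aliases img[Y + y*dy, X + x*dx].
  # Same trick as the patch: only the two window strides are scaled by the dilation.
  iy, ix = img.shape
  oy, ox = iy - dy*(H-1), ix - dx*(W-1)  # effective kernel is dy*(H-1)+1 by dx*(W-1)+1
  sy, sx = img.strides
  return np.lib.stride_tricks.as_strided(
    img, shape=(oy, ox, H, W), strides=(sy, sx, sy*dy, sx*dx), writeable=False)

img = np.arange(7*9, dtype=np.float32).reshape(7, 9)
w = np.arange(9, dtype=np.float32).reshape(3, 3)
out = np.einsum('yxhw,hw->yx', dilated_windows(img, 3, 3, 2, 2), w)

# reference using the same (y*dy+IY, x*dx+IX) indexing the OpenCL kernel now does
ref = np.zeros_like(out)
for Y in range(out.shape[0]):
  for X in range(out.shape[1]):
    for y in range(3):
      for x in range(3):
        ref[Y, X] += img[Y + y*2, X + x*2] * w[y, x]
assert np.allclose(out, ref)

The ops_gpu.py hunk is the same idea written as index arithmetic: the kernel loops y,x
over the filter itself and reads the input at (y*dy+IY, x*dx+IX), which reduces to the
old contiguous window when dy = dx = 1.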