diff --git a/examples/efficientnet.py b/examples/efficientnet.py
index e50f0fd8a0..99ba8dedc7 100644
--- a/examples/efficientnet.py
+++ b/examples/efficientnet.py
@@ -119,8 +119,10 @@ if __name__ == "__main__":
   # load cat image and preprocess
   from PIL import Image
   img = Image.open(io.BytesIO(fetch("https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg")))
-  img = img.resize((224, 224))
-  img = np.moveaxis(np.array(img), [2,0,1], [0,1,2])
+  img = img.resize((398, 224))
+  img = np.array(img)
+  img = img[:, 87:-87]
+  img = np.moveaxis(img, [2,0,1], [0,1,2])
   img = img.astype(np.float32).reshape(1,3,224,224)
   img /= 256
   img -= np.array([0.485, 0.456, 0.406]).reshape((1,-1,1,1))
diff --git a/test/test_ops.py b/test/test_ops.py
index e110c76173..0d47e92f16 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -48,11 +48,12 @@ class TestOps(unittest.TestCase):
   def test_conv2d(self):
     for bs in [1,8]:
       for cin in [1,3]:
-        for H in [2,5]:
-          for W in [2,3,5]:
-            helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
-              lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
-              lambda x,w: Tensor.conv2d(x,w).relu(), atol=2e-5, grad_atol=2e-6)
+        for groups in [1,3] if cin == 3 else [1]:
+          for H in [2,5]:
+            for W in [2,3,5]:
+              helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
+                lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
+                lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=2e-5, grad_atol=2e-6)
 
   def test_strided_conv2d(self):
     bs = 4
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index d2591de495..6a18708657 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -155,21 +155,23 @@ class Conv2D(Function):
   def forward(ctx, x, w, stride=1, groups=1):
     if type(ctx.stride) == int:
       ctx.stride = (ctx.stride, ctx.stride)
     cout,cin,H,W = w.shape
-    if groups > 1:
-      w = np.repeat(w, groups, axis=1) / groups
-    tw = w.reshape(cout, -1).T
     ys,xs = ctx.stride
-    bs,oy,ox = x.shape[0], (x.shape[2]-(H-ys))//ys, (x.shape[3]-(W-xs))//xs
+    bs,cin_,oy,ox = x.shape[0], x.shape[1], (x.shape[2]-(H-ys))//ys, (x.shape[3]-(W-xs))//xs
+    assert cin*ctx.groups == cin_
+    assert cout % ctx.groups == 0
+    rcout = cout//ctx.groups
     ctx.save_for_backward(x, w)
     ret = np.zeros((bs, cout, oy, ox), dtype=w.dtype)
-    for Y in range(oy):
-      for X in range(ox):
-        iY,iX = Y*ys, X*xs
-        tx = x[:, :, iY:iY+H, iX:iX+W].reshape(bs, -1)
-        ret[:, :, Y, X] = tx.dot(tw)
+
+    for g in range(ctx.groups):
+      tw = w[g*rcout:(g*rcout+rcout)].reshape(rcout, -1).T
+      for Y in range(oy):
+        for X in range(ox):
+          iY,iX = Y*ys, X*xs
+          tx = x[:, g*cin:(g*cin+cin), iY:iY+H, iX:iX+W].reshape(bs, -1)
+          ret[:, g*rcout:(g*rcout+rcout), Y, X] += tx.dot(tw)
     return ret
 
   @staticmethod
@@ -177,17 +179,19 @@ class Conv2D(Function):
   def backward(ctx, grad_output):
     bs,_,oy,ox = grad_output.shape
     x, w = ctx.saved_tensors
     cout,cin,H,W = w.shape
-    tw = w.reshape(cout, -1)
     ys,xs = ctx.stride
+    rcout = cout//ctx.groups
     dx, dw = np.zeros_like(x), np.zeros_like(w)
-    for Y in range(grad_output.shape[2]):
-      for X in range(grad_output.shape[3]):
-        iY,iX = Y*ys, X*xs
-        gg = grad_output[:, :, Y, X]
-        tx = x[:, :, iY:iY+H, iX:iX+W].reshape(x.shape[0], -1)
-        dw += gg.T.dot(tx).reshape(dw.shape)
-        dx[:, :, iY:iY+H, iX:iX+W] += gg.dot(tw).reshape(dx.shape[0], dx.shape[1], H, W)
+    for g in range(ctx.groups):
+      tw = w[g*rcout:(g*rcout+rcout)].reshape(rcout, -1)
+      for Y in range(grad_output.shape[2]):
+        for X in range(grad_output.shape[3]):
+          iY,iX = Y*ys, X*xs
+          gg = grad_output[:, g*rcout:(g*rcout+rcout), Y, X]
+          tx = x[:, g*cin:(g*cin+cin), iY:iY+H, iX:iX+W].reshape(x.shape[0], -1)
+          dw[g*rcout:(g*rcout+rcout)] += gg.T.dot(tx).reshape((rcout,cin,H,W))
+          dx[:, g*cin:(g*cin+cin), iY:iY+H, iX:iX+W] += gg.dot(tw).reshape(dx.shape[0], cin, H, W)
     return dx, dw
 
 register('conv2d', Conv2D)
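Note (not part of the patch): a minimal standalone sketch of the grouped-convolution forward pass that the new Conv2D.forward implements, checked against torch.nn.functional.conv2d. The name grouped_conv2d_forward and the shapes below are illustrative assumptions, not code from this diff.

import numpy as np
import torch

def grouped_conv2d_forward(x, w, stride=(1, 1), groups=1):
  # x: (bs, cin_, iy, ix); w: (cout, cin, H, W) with cin = cin_ // groups
  cout, cin, H, W = w.shape
  ys, xs = stride
  bs, cin_ = x.shape[0], x.shape[1]
  oy, ox = (x.shape[2]-(H-ys))//ys, (x.shape[3]-(W-xs))//xs
  assert cin*groups == cin_      # input channels must split evenly across groups
  assert cout % groups == 0      # output channels must split evenly across groups
  rcout = cout // groups         # output channels per group
  ret = np.zeros((bs, cout, oy, ox), dtype=w.dtype)
  for g in range(groups):
    # each group convolves only its own slice of input and output channels
    tw = w[g*rcout:(g+1)*rcout].reshape(rcout, -1).T
    for Y in range(oy):
      for X in range(ox):
        iY, iX = Y*ys, X*xs
        tx = x[:, g*cin:(g+1)*cin, iY:iY+H, iX:iX+W].reshape(bs, -1)
        ret[:, g*rcout:(g+1)*rcout, Y, X] += tx.dot(tw)
  return ret

# depthwise-style check: 6 input channels, 3 groups, 9 output channels
x = np.random.randn(2, 6, 10, 10).astype(np.float32)
w = np.random.randn(9, 2, 3, 3).astype(np.float32)
ref = torch.nn.functional.conv2d(torch.tensor(x), torch.tensor(w), groups=3).numpy()
assert np.allclose(grouped_conv2d_forward(x, w, groups=3), ref, atol=1e-5)

Slicing x by g*cin and w by g*rcout is what replaces the old np.repeat approximation: each filter now carries only cin_//groups input channels, which is what makes depthwise convolutions (groups equal to the input channel count) cheap for EfficientNet.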
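A matching sketch of the grouped backward pass, validated against torch autograd with an all-ones upstream gradient. As above, grouped_conv2d_backward and the shapes are illustrative assumptions; per group, dw accumulates gg.T.dot(tx) and dx scatter-adds gg.dot(tw), mirroring the per-group structure of the new Conv2D.backward.

import numpy as np
import torch

def grouped_conv2d_backward(x, w, grad_output, stride=(1, 1), groups=1):
  cout, cin, H, W = w.shape
  ys, xs = stride
  rcout = cout // groups
  dx, dw = np.zeros_like(x), np.zeros_like(w)
  for g in range(groups):
    tw = w[g*rcout:(g+1)*rcout].reshape(rcout, -1)
    for Y in range(grad_output.shape[2]):
      for X in range(grad_output.shape[3]):
        iY, iX = Y*ys, X*xs
        gg = grad_output[:, g*rcout:(g+1)*rcout, Y, X]   # upstream grad for this group
        tx = x[:, g*cin:(g+1)*cin, iY:iY+H, iX:iX+W].reshape(x.shape[0], -1)
        dw[g*rcout:(g+1)*rcout] += gg.T.dot(tx).reshape((rcout, cin, H, W))
        dx[:, g*cin:(g+1)*cin, iY:iY+H, iX:iX+W] += gg.dot(tw).reshape(x.shape[0], cin, H, W)
  return dx, dw

x = np.random.randn(2, 6, 10, 10).astype(np.float32)
w = np.random.randn(9, 2, 3, 3).astype(np.float32)
xt, wt = torch.tensor(x, requires_grad=True), torch.tensor(w, requires_grad=True)
out = torch.nn.functional.conv2d(xt, wt, groups=3)
out.backward(torch.ones_like(out))
dx, dw = grouped_conv2d_backward(x, w, np.ones(out.shape, dtype=np.float32), groups=3)
assert np.allclose(dx, xt.grad.numpy(), atol=1e-4)
assert np.allclose(dw, wt.grad.numpy(), atol=1e-4)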