From 285621aedaa27e068759dfdfa90b82885daa6fe2 Mon Sep 17 00:00:00 2001
From: Evan Mays
Date: Sat, 30 Oct 2021 19:12:19 -0400
Subject: [PATCH] Cherry backprop for conv2d (#281)

* quick math: 0 + x = x.

* gradient w.r.t. x using cherry for conv

* gradient w.r.t. w for conv on cherry but doing vector dot products

* small optimization

* [cherry] optimize conv backpass for large channel count

* get rid of numpy einsum
---
 extra/ops_cherry.py | 31 ++++++++++++++++++++++++-------
 test/test_ops.py    | 12 +++++++++++-
 2 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/extra/ops_cherry.py b/extra/ops_cherry.py
index ce8c1bac0d..f70570d708 100644
--- a/extra/ops_cherry.py
+++ b/extra/ops_cherry.py
@@ -264,19 +264,36 @@ class Conv2D(Function):
     ggg = grad_output.reshape(bs,ctx.groups,rcout,oy,ox)
 
     gdw = np.zeros((ctx.groups,rcout,cin,H,W), dtype=tx.dtype)
-    for g in range(ctx.groups):
-      #'ikYX,ijYXyx -> kjyx'
-      gdw[g] += np.tensordot(ggg[:,g], tx[:,g], ((0,2,3),(0,2,3)))
-    # needs to be optimized
+    if cin >= 16:
+      # optimize for large channel count
+      for g in range(ctx.groups):
+        #'ikYX,ijYXyx -> kjyx'
+        for i in range(ggg[:,g].shape[1]):
+          for m in range(tx[:,g].shape[4]):
+            for n in range(tx[:,g].shape[5]):
+              # Use transposes to ensure reshape keeps the correct dimension (channel dimension) when multiple dimensions have the same size
+              big_matrix = np.transpose(tx[:,g][:, :, :, :, m, n], (1, 0, 2, 3)).reshape(tx[:,g].shape[1], -1).T
+              gdw[g][i, :, m, n] = cherry_matmul(ggg[:,g][:,i].reshape(1, -1), big_matrix).flatten()
+    else:
+      # unoptimized
+      for g in range(ctx.groups):
+        #'ikYX,ijYXyx -> kjyx'
+        for i in range(ggg[:,g].shape[1]):
+          for j in range(tx[:,g].shape[1]):
+            for m in range(tx[:,g].shape[4]):
+              big_matrix = tx[:,g][:,j, :, :, m].reshape(-1, tx[:,g].shape[5])
+              gdw[g][i, j, m] = cherry_matmul(ggg[:,g][:,i].reshape(1, -1), big_matrix).flatten()
+
+    # needs to be optimized separately for large oy and ox, versus large ctx.groups
     gdx = np.zeros((bs,ctx.groups,cin,OY,OX), dtype=tx.dtype)
     for k in range(oy*ox):
       Y, X = k//ox, k%ox
       iY,iX = Y*ys, X*xs
-      #gdx[:,:,: , iY:iY+H, iX:iX+W] += np.einsum('igk,gkjyx->igjyx', ggg[:,:,:,Y,X], tw)
+      big_matrix = []
       for g in range(ctx.groups):
-        tg = np.dot(ggg[:,g,:,Y,X].reshape(bs, -1), tw[g].reshape(rcout, -1))
-        gdx[:, g, :, iY:iY+H, iX:iX+W] += tg.reshape((bs, cin, H, W))
+        big_matrix.append(cherry_matmul(ggg[:,g,:,Y,X].reshape(bs, -1), tw[g].reshape(rcout, -1)).reshape((bs, cin, H, W)))
+      gdx[:, :, :, iY:iY+H, iX:iX+W] = cherry_binop(gdx[:, :, :, iY:iY+H, iX:iX+W], np.array(np.transpose(big_matrix, (1, 0, 2, 3, 4))), BinaryOps.ADD)
 
     return gdx.reshape((bs, ctx.groups*cin, OY, OX)), gdw.reshape((ctx.groups*rcout, cin, H, W))
 
diff --git a/test/test_ops.py b/test/test_ops.py
index 7fe1bd5ee1..353eaca6c6 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -164,6 +164,16 @@ class TestOps(unittest.TestCase):
     helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
       lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
       lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
+
+  def test_large_input_conv2d(self):
+    bs = 4
+    cin = 16
+    groups = 1
+    H = 5
+    W = 2
+    helper_test_op([(bs,cin,64,64), (6,cin//groups,H,W)],
+      lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
+      lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
 
   def test_grouped_conv2d(self):
     groups = 2
@@ -179,7 +189,7 @@ class TestOps(unittest.TestCase):
     H,W = 3,3
     helper_test_op([(bs,cin,11,28), (groups*cout,cin//groups,H,W)],
       lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
-      lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5, forward_only=True)
+      lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
 
   def test_strided_conv2d(self):
     bs = 4
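
A quick NumPy sanity check of the gdw rewrite above: the old code contracted everything with a single np.tensordot, while the new code does one matmul per (output channel, kernel y, kernel x), flattening the bs*oy*ox axes, and transposes before the reshape so the channel axis survives when several axes share a size. This sketch assumes cherry_matmul behaves as a plain matrix multiply (np.matmul stands in for it), and the sizes are made up, loosely mirroring test_large_input_conv2d.

import numpy as np

# Hypothetical sizes for one group: cin >= 16 takes the optimized path.
bs, rcout, cin, oy, ox, H, W = 4, 6, 16, 8, 8, 5, 2

ggg_g = np.random.randn(bs, rcout, oy, ox)        # grad_output slice for one group
tx_g = np.random.randn(bs, cin, oy, ox, H, W)     # im2col-style input view for one group

# Reference: the old contraction, 'ikYX,ijYXyx -> kjyx'.
ref = np.tensordot(ggg_g, tx_g, ((0, 2, 3), (0, 2, 3)))

# Loop form from the patch: one (1 x bs*oy*ox) @ (bs*oy*ox x cin) matmul per (i, m, n).
gdw_g = np.zeros((rcout, cin, H, W))
for i in range(rcout):
  for m in range(H):
    for n in range(W):
      # Transpose first so the reshape keeps cin as the row axis even when sizes collide.
      big_matrix = np.transpose(tx_g[:, :, :, :, m, n], (1, 0, 2, 3)).reshape(cin, -1).T
      # np.matmul stands in for cherry_matmul here.
      gdw_g[i, :, m, n] = np.matmul(ggg_g[:, i].reshape(1, -1), big_matrix).flatten()

assert np.allclose(ref, gdw_g)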
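
The same kind of check for the gdx rewrite: the old per-group np.dot accumulation versus the patch's stack-then-transpose form, with np.matmul again standing in for cherry_matmul and a plain add standing in for cherry_binop(..., BinaryOps.ADD). Sizes are again hypothetical, and only a single output position (Y, X) is shown since the loop over k treats each one identically.

import numpy as np

# Hypothetical sizes; groups > 1 exercises the stacking and transpose.
bs, groups, rcout, cin, H, W = 4, 2, 3, 8, 5, 2
ggg_YX = np.random.randn(bs, groups, rcout)       # ggg[:,:,:,Y,X] at one output position
tw = np.random.randn(groups, rcout, cin, H, W)    # weights

# Reference: the old per-group np.dot accumulation into the input-gradient window.
ref = np.zeros((bs, groups, cin, H, W))
for g in range(groups):
  tg = np.dot(ggg_YX[:, g].reshape(bs, -1), tw[g].reshape(rcout, -1))
  ref[:, g] += tg.reshape((bs, cin, H, W))

# Patch form: stack per-group results, move the group axis behind batch, add once.
big_matrix = []
for g in range(groups):
  big_matrix.append(np.matmul(ggg_YX[:, g].reshape(bs, -1), tw[g].reshape(rcout, -1)).reshape((bs, cin, H, W)))
out = np.zeros((bs, groups, cin, H, W))
out = out + np.array(np.transpose(big_matrix, (1, 0, 2, 3, 4)))  # cherry_binop ADD stand-in

assert np.allclose(ref, out)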