mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
Cherry backprop for conv2d (#281)
* quick math: 0 + x = x.
* gradient w.r.t. x using cherry for conv
* gradient w.r.t. w for conv on cherry but doing vector dot products
* small optimization
* [cherry] optimize conv backpass for large channel count
* get rid of numpy einsum
This commit is contained in:
@@ -164,6 +164,16 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
|
||||
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
def test_large_input_conv2d(self):
|
||||
bs = 4
|
||||
cin = 16
|
||||
groups = 1
|
||||
H = 5
|
||||
W = 2
|
||||
helper_test_op([(bs,cin,64,64), (6,cin//groups,H,W)],
|
||||
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
def test_grouped_conv2d(self):
|
||||
groups = 2
|
||||
@@ -179,7 +189,7 @@ class TestOps(unittest.TestCase):
|
||||
H,W = 3,3
|
||||
helper_test_op([(bs,cin,11,28), (groups*cout,cin//groups,H,W)],
|
||||
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
|
||||
- lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5, forward_only=True)
|
||||
+ lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
def test_strided_conv2d(self):
|
||||
bs = 4
|
||||
|
||||
Reference in New Issue
Block a user