mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix kernel_size bug, name like torch, add test
This commit is contained in:
@@ -32,7 +32,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-7, grad_atol=1e-7):
|
||||
|
||||
class TestOps(unittest.TestCase):
|
||||
def test_conv2d(self):
|
||||
for bs in [1,128]:
|
||||
for bs in [1,8]:
|
||||
for cin in [1,3]:
|
||||
for H in [2,5]:
|
||||
for W in [2,3,5]:
|
||||
@@ -43,6 +43,12 @@ class TestOps(unittest.TestCase):
|
||||
def test_maxpool2x2(self):
|
||||
helper_test_op([(32,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, (2,2)), Tensor.max_pool2d)
|
||||
|
||||
def test_maxpool_sizes(self):
|
||||
for sz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
|
||||
helper_test_op([(32,2,110,28)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=sz),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=sz))
|
||||
|
||||
def test_avgpool2x2(self):
|
||||
helper_test_op([(32,2,111,28)], lambda x: torch.nn.functional.avg_pool2d(x, (2,2)), Tensor.avg_pool2d)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user