clean up more conv tests (#11510)
mirror of https://github.com/tinygrad/tinygrad.git
@@ -125,43 +125,22 @@ class TestNN(unittest.TestCase):
     torch_z = torch_layer(torch_x)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
 
-  def test_conv1d(self):
-    BS, C1, DIMS = 4, 16, [224//4]
-    C2, K, S, P = 64, 7, 2, 1
-    self._test_conv(Conv1d, torch.nn.Conv1d, BS, C1, DIMS, C2, K, S, P)
-
-  def test_conv2d(self):
-    BS, C1, DIMS = 4, 16, [224//4, 224//4]
-    C2, K, S, P = 64, 7, 2, 1
-    self._test_conv(Conv2d, torch.nn.Conv2d, BS, C1, DIMS, C2, K, S, P)
+  def test_conv1d(self): self._test_conv(Conv1d, torch.nn.Conv1d, BS=4, C1=16, DIMS=[224//4], C2=64, K=7, S=2, P=1)
+  def test_conv2d(self): self._test_conv(Conv2d, torch.nn.Conv2d, BS=4, C1=16, DIMS=[224//4, 224//4], C2=64, K=7, S=2, P=1)
 
   def test_conv1d_same_padding(self):
-    BS, C1, DIMS = 8, 3, [32]
-    C2, K, S, P = 16, 3, 1, 'same'
-    self._test_conv(Conv1d, torch.nn.Conv1d, BS, C1, DIMS, C2, K, S, P)
-
+    self._test_conv(Conv1d, torch.nn.Conv1d, BS=8, C1=3, DIMS=[32], C2=16, K=3, S=1, P='same')
   def test_conv2d_same_padding_odd_input(self):
-    BS, C1, DIMS = 16, 16, [29, 31]
-    C2, K, S, P = 32, 5, 1, 'same'
-    self._test_conv(Conv2d, torch.nn.Conv2d, BS, C1, DIMS, C2, K, S, P)
-
+    self._test_conv(Conv2d, torch.nn.Conv2d, BS=16, C1=16, DIMS=[29, 31], C2=32, K=5, S=1, P='same')
   def test_conv2d_same_padding_large_kernel(self):
-    BS, C1, DIMS = 16, 16, [28, 33]
-    C2, K, S, P = 32, 9, 1, 'same'
-    self._test_conv(Conv2d, torch.nn.Conv2d, BS, C1, DIMS, C2, K, S, P)
-
+    self._test_conv(Conv2d, torch.nn.Conv2d, BS=16, C1=16, DIMS=[28, 33], C2=32, K=9, S=1, P='same')
   def test_conv2d_same_padding_with_dilation(self):
-    BS, C1, DIMS = 16, 3, [28, 28]
-    C2, K, S, P, D = 32, 3, 1, 'same', 3
-    self._test_conv(Conv2d, torch.nn.Conv2d, BS, C1, DIMS, C2, K, S, P, D)
+    self._test_conv(Conv2d, torch.nn.Conv2d, BS=16, C1=3, DIMS=[28, 28], C2=32, K=3, S=1, P='same', D=3)
 
   def test_conv2d_same_padding_invalid_stride(self):
-    C1, C2, K, S, P = 16, 32, 2, 2, 'same'
-    self.assertRaises(ValueError, Conv2d, C1, C2, kernel_size=K, stride=S, padding=P)
-
+    self.assertRaises(ValueError, Conv2d, in_channels=16, out_channels=32, kernel_size=2, stride=2, padding='same')
   def test_conv2d_same_padding_invalid_padding_str(self):
-    C1, C2, K, S, P = 16, 32, 2, 1, 'not_same'
-    self.assertRaises(ValueError, Conv2d, C1, C2, kernel_size=K, stride=S, padding=P)
+    self.assertRaises(ValueError, Conv2d, in_channels=16, out_channels=32, kernel_size=2, stride=1, padding='not_same')
 
   @unittest.skip("Takes too long to compile for Compiled backends")
   def test_conv2d_winograd(self):
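The two ValueError tests in this hunk pin down the 'same'-padding rules these layers follow: padding='same' must preserve the spatial size, which is only well-defined for stride 1, and any other padding string is rejected. A minimal sketch of that resolution logic, for orientation only (this is a hypothetical helper, not tinygrad's actual Conv2d code):

  # Hypothetical sketch of 'same'-padding resolution; not tinygrad's implementation.
  def resolve_same_padding(kernel_size: int, stride: int, dilation: int = 1) -> tuple[int, int]:
    if stride != 1: raise ValueError("padding='same' requires stride=1")  # cf. test_conv2d_same_padding_invalid_stride
    total = dilation * (kernel_size - 1)      # padding needed to keep the output size equal to the input size
    return (total // 2, total - total // 2)   # split across the two sides, extra on one side for even totals

  assert resolve_same_padding(3, 1) == (1, 1)
  assert resolve_same_padding(5, 1) == (2, 2)              # cf. test_conv2d_same_padding_odd_input
  assert resolve_same_padding(3, 1, dilation=3) == (3, 3)  # cf. test_conv2d_same_padding_with_dilation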
@@ -200,14 +179,9 @@ class TestNN(unittest.TestCase):
     np.testing.assert_allclose(gx.numpy(), torch_x.grad.numpy(), atol=5e-4, rtol=1e-5)
 
   def test_conv_transpose1d(self):
-    BS, C1, DIMS = 4, 16, [224//4]
-    C2, K, S, P = 64, 7, 2, 1
-    self._test_conv(ConvTranspose1d, torch.nn.ConvTranspose1d, BS, C1, DIMS, C2, K, S, P)
-
+    self._test_conv(ConvTranspose1d, torch.nn.ConvTranspose1d, BS=4, C1=16, DIMS=[224//4], C2=64, K=7, S=2, P=1)
   def test_conv_transpose2d(self):
-    BS, C1, DIMS = 4, 16, [224//4, 224//4]
-    C2, K, S, P = 64, 7, 2, 1
-    self._test_conv(ConvTranspose2d, torch.nn.ConvTranspose2d, BS, C1, DIMS, C2, K, S, P)
+    self._test_conv(ConvTranspose2d, torch.nn.ConvTranspose2d, BS=4, C1=16, DIMS=[224//4, 224//4], C2=64, K=7, S=2, P=1)
 
   def test_groupnorm(self):
     BS, H, W, C, G = 20, 10, 10, 6, 3
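All of the rewritten tests lean on a _test_conv helper defined earlier in TestNN, now called with keyword arguments. For orientation, here is a minimal sketch of what such a keyword-style helper could look like, consistent with the context lines visible in the diff (the torch_z comparison and the 5e-4/1e-5 tolerances); the real helper in the test file may differ:

  # Hypothetical sketch of a keyword-argument _test_conv; the actual helper may differ.
  def _test_conv(self, tiny_cls, torch_cls, BS, C1, DIMS, C2, K, S, P, D=1):
    layer = tiny_cls(C1, C2, kernel_size=K, stride=S, padding=P, dilation=D)
    torch_layer = torch_cls(C1, C2, kernel_size=K, stride=S, padding=P, dilation=D)
    # copy tinygrad's initialization into the torch layer so the outputs are comparable
    with torch.no_grad():
      torch_layer.weight.copy_(torch.tensor(layer.weight.numpy()))
      torch_layer.bias.copy_(torch.tensor(layer.bias.numpy()))
    x = Tensor.uniform(BS, C1, *DIMS)
    torch_x = torch.tensor(x.numpy())
    z, torch_z = layer(x), torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

Passing BS, C1, DIMS, etc. as keywords is what lets each test collapse to a single self-describing line, which is the whole point of this cleanup.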