add support for padding='same' in nn.conv (#6975)
* add support for padding='same' in nn.conv
* express concisely
* simplify loop
* test same padding with dilation and conv1d
* fix bad indentation
* make loop one liner
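The diff below only shows the tests, not the implementation, so as a reading aid here is a minimal sketch of the 'same' padding rule the tests imply. This is an assumption based on PyTorch's convention (which the tests compare against), not the exact code this commit adds: each spatial dimension gets dilation * (kernel - 1) total padding, split low/high with the extra element on the high side, and the stride must be 1.

# Hypothetical helper illustrating the 'same' padding rule verified against
# torch in the tests below; the name and structure are illustrative only.
def same_padding(kernel_size, dilation):
  # total padding per dim is dilation*(kernel-1); when odd, the extra
  # element goes on the high side (matching PyTorch's convention)
  return [(d*(k-1)//2, d*(k-1) - d*(k-1)//2) for k, d in zip(kernel_size, dilation)]

assert same_padding((3,), (1,)) == [(1, 1)]   # odd kernel: symmetric
assert same_padding((4,), (1,)) == [(1, 2)]   # even kernel: one extra at the end
assert same_padding((3,), (3,)) == [(3, 3)]   # dilation scales the pad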
@@ -172,6 +172,66 @@ class TestNN(unittest.TestCase):
     torch_z = torch_layer(torch_x)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
+
+  def test_conv1d_same_padding(self):
+    BS, C1, W = 8, 3, 32
+    C2, K, S, P = 16, 3, 1, 'same'
+
+    # create in tinygrad
+    layer = Conv1d(C1, C2, kernel_size=K, stride=S, padding=P)
+
+    # create in torch
+    with torch.no_grad():
+      torch_layer = torch.nn.Conv1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
+      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
+      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
+
+    # test
+    x = Tensor.uniform(BS, C1, W)
+    z = layer(x)
+    torch_x = torch.tensor(x.numpy())
+    torch_z = torch_layer(torch_x)
+    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
+
+  def _run_conv2d_same_padding_test(self, BS, C1, C2, H, W, K, S, padding='same', D=1):
+    # create in tinygrad
+    layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=padding, dilation=D)
+
+    # create in torch
+    with torch.no_grad():
+      torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=padding, dilation=D).eval()
+      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
+      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
+
+    # test
+    x = Tensor.uniform(BS, C1, H, W)
+    z = layer(x)
+    torch_x = torch.tensor(x.numpy())
+    torch_z = torch_layer(torch_x)
+    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
+
+  def test_conv2d_same_padding_odd_input(self):
+    BS, C1, H, W = 16, 16, 29, 31
+    C2, K, S, P = 32, 4, 1, 'same'
+    self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P)
+
+  def test_conv2d_same_padding_large_kernel(self):
+    BS, C1, H, W = 16, 16, 28, 33
+    C2, K, S, P = 32, 9, 1, 'same'
+    self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P)
+
+  def test_conv2d_same_padding_with_dilation(self):
+    BS, C1, H, W = 16, 3, 28, 28
+    C2, K, S, P, D = 32, 3, 1, 'same', 3
+    self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P, D)
+
+  def test_conv2d_same_padding_invalid_stride(self):
+    C1, C2, K, S, P = 16, 32, 2, 2, 'same'
+    self.assertRaises(ValueError, Conv2d, C1, C2, kernel_size=K, stride=S, padding=P)
+
+  def test_conv2d_same_padding_invalid_padding_str(self):
+    C1, C2, K, S, P = 16, 32, 2, 1, 'not_same'
+    self.assertRaises(ValueError, Conv2d, C1, C2, kernel_size=K, stride=S, padding=P)
+
   @unittest.skip("Takes too long to compile for Compiled backends")
   def test_conv2d_winograd(self):
     BS, C1, H, W = 2, 8, 16, 16
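For context, a short usage sketch of the feature these tests exercise (assuming current top-level tinygrad import paths): with stride 1, padding='same' preserves the spatial dimensions even under dilation, while any other stride raises ValueError, as test_conv2d_same_padding_invalid_stride checks.

from tinygrad import Tensor
from tinygrad.nn import Conv2d

# 'same' keeps H and W unchanged for stride=1, including with dilation
layer = Conv2d(3, 16, kernel_size=3, stride=1, padding='same', dilation=2)
x = Tensor.uniform(4, 3, 28, 28)
print(layer(x).shape)  # (4, 16, 28, 28): spatial dims are preserved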