mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add flip
This commit is contained in:
@@ -161,6 +161,12 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)))
|
||||
|
||||
def test_flip(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,)), lambda x: x.flip(axis=(0,)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1)), lambda x: x.flip(axis=(0,1)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)), lambda x: x.flip(axis=(0,1,3)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,)))
|
||||
|
||||
def test_flatten(self):
|
||||
for axis in range(3):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, start_dim=axis), lambda x: x.flatten(axis))
|
||||
@@ -196,7 +202,8 @@ class TestOps(unittest.TestCase):
|
||||
W = 2
|
||||
helper_test_op([(bs,cin,64,64), (6,cin//groups,H,W)],
|
||||
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
# needed to relax tolerance on NVIDIA
|
||||
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-3, grad_rtol=1e-5)
|
||||
|
||||
def test_simple_grouped_conv2d(self):
|
||||
bs = 1
|
||||
|
||||
Reference in New Issue
Block a user