merge TestRollEdgeCases into test_ops (#11321)

This commit is contained in:
chenyu
2025-07-22 10:55:57 -04:00
committed by GitHub
parent 1d8b3e9d1c
commit fb42c84365
2 changed files with 6 additions and 18 deletions

View File

@@ -86,9 +86,11 @@ def prepare_test_op(low, high, shps, vals, forward_only=False):
class TestOps(unittest.TestCase):
def helper_test_exception(self, shps, torch_fxn, tinygrad_fxn, expected, forward_only=False, exact=False, vals=None, low=-1.5, high=1.5):
def helper_test_exception(self, shps, torch_fxn, tinygrad_fxn=None, expected=None, forward_only=False, exact=False, vals=None, low=-1.5, high=1.5):
if getenv("MOCKGPU") and Device.DEFAULT == "NV": self.skipTest('helper_test_exception fails in CI CUDA')
ts, tst = prepare_test_op(low, high, shps, vals, forward_only)
if tinygrad_fxn is None:
tinygrad_fxn = torch_fxn
with self.assertRaises(expected) as torch_cm:
torch_fxn(*ts)
with self.assertRaises(expected) as tinygrad_cm:
@@ -1939,7 +1941,7 @@ class TestOps(unittest.TestCase):
def test_roll(self):
helper_test_op([(2, 4)], lambda x: x.roll(1))
helper_test_op([(2, 4)], lambda x: x.roll((1,)))
self.helper_test_exception([(2, 4)], lambda x: x.roll((1,2)), lambda x: x.roll((1,2)), expected=RuntimeError)
self.helper_test_exception([(2, 4)], lambda x: x.roll((1, 2)), expected=RuntimeError)
helper_test_op([(2, 4)], lambda x: x.roll(1, 0))
helper_test_op([(2, 4)], lambda x: x.roll(-1, 0))
helper_test_op([(2, 4)], lambda x: x.roll(shifts=(2, 1), dims=(0, 1)))
@@ -1953,6 +1955,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(2, 4)], lambda x: x.roll(0, 0))
helper_test_op([(2, 4, 6)], lambda x: x.roll(shifts=(0, 0), dims=(0, 1)))
helper_test_op([(2, 4, 6)], lambda x: x.roll(shifts=(0, 2), dims=(0, 1)))
self.helper_test_exception([(3, 3)], lambda x: x.roll(shifts=1, dims=(0, 1)), expected=RuntimeError)
self.helper_test_exception([(10,)], lambda x: x.roll(shifts=(1, 2), dims=0), expected=RuntimeError)
def test_detach(self):
helper_test_op([(4,3,6,6)], lambda x: x.detach(), forward_only=True)