From c158e3c9889fd3f0c7a03d395adc6eeb88f86f0b Mon Sep 17 00:00:00 2001 From: Roelof van Dijk <3604013+roelofvandijk@users.noreply.github.com> Date: Tue, 2 Dec 2025 22:02:47 +0100 Subject: [PATCH] add cifar gated uop_given_valid regression test (#13536) --- test/unit/test_uop_symbolic.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 3b3f185b24..c123c9f711 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -6,7 +6,7 @@ from tinygrad.dtype import dtypes, ConstType, DType, Invalid from tinygrad.codegen import full_rewrite from tinygrad.helpers import Context from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer -from tinygrad.uop.symbolic import sym, commutative +from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid from tinygrad.uop.validate import uops_to_z3 def check_uop_against_string(self, v:UOp, s:str): @@ -1029,6 +1029,25 @@ class TestSymbolicRealWorld(unittest.TestCase): self.assertIn(idx.render(), ("(lidx3+((lidx5+1)//16*802816+(lidx5+1)%16*49+gidx0*3211264+gidx1*784+gidx2*8+lidx4*100352)+2207744)",)) +class TestGatedUopGivenValid(unittest.TestCase): + def test_invalid_gate_simplifies_index(self): + r0 = Variable("r0", 0, 2) + + idx:UOp = (r0 < 3).where((r0 + uconst(-1)) // uconst(3), UOp.invalid()) + idx = graph_rewrite(idx, pm_simplify_valid) + self.assertEqual(idx, (r0 < 3).where(uconst(0), UOp.invalid())) + + def test_invalid_gate_simplifies_vectorize(self): + r0 = Variable("r0", 0, 2) + + idx0 = (r0 + uconst(-1)) // uconst(3) + idx1 = r0 % uconst(3) + idx:UOp = (r0 < 3).where(UOp(Ops.VECTORIZE, dtypes.index.vec(2), (idx0, idx1)), UOp.invalid()) + idx = graph_rewrite(idx, pm_simplify_valid) + # NOTE: independent simplification: (r0-1)//3 -> 0, r0%3 -> r0 when r0 in [0,2] + expected_vec = UOp(Ops.VECTORIZE, dtypes.index.vec(2), (uconst(0), r0)) + self.assertEqual(idx, (r0 < 3).where(expected_vec, UOp.invalid())) + class TestBounds(unittest.TestCase): def test_unrolled_arange(self): # #include