add cifar gated uop_given_valid regression test (#13536)

This commit is contained in:
Roelof van Dijk
2025-12-02 22:02:47 +01:00
committed by GitHub
parent e329baffa7
commit c158e3c988

View File

@@ -6,7 +6,7 @@ from tinygrad.dtype import dtypes, ConstType, DType, Invalid
from tinygrad.codegen import full_rewrite
from tinygrad.helpers import Context
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer
from tinygrad.uop.symbolic import sym, commutative
from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid
from tinygrad.uop.validate import uops_to_z3
def check_uop_against_string(self, v:UOp, s:str):
@@ -1029,6 +1029,25 @@ class TestSymbolicRealWorld(unittest.TestCase):
self.assertIn(idx.render(),
("(lidx3+((lidx5+1)//16*802816+(lidx5+1)%16*49+gidx0*3211264+gidx1*784+gidx2*8+lidx4*100352)+2207744)",))
class TestGatedUopGivenValid(unittest.TestCase):
def test_invalid_gate_simplifies_index(self):
r0 = Variable("r0", 0, 2)
idx:UOp = (r0 < 3).where((r0 + uconst(-1)) // uconst(3), UOp.invalid())
idx = graph_rewrite(idx, pm_simplify_valid)
self.assertEqual(idx, (r0 < 3).where(uconst(0), UOp.invalid()))
def test_invalid_gate_simplifies_vectorize(self):
r0 = Variable("r0", 0, 2)
idx0 = (r0 + uconst(-1)) // uconst(3)
idx1 = r0 % uconst(3)
idx:UOp = (r0 < 3).where(UOp(Ops.VECTORIZE, dtypes.index.vec(2), (idx0, idx1)), UOp.invalid())
idx = graph_rewrite(idx, pm_simplify_valid)
# NOTE: independent simplification: (r0-1)//3 -> 0, r0%3 -> r0 when r0 in [0,2]
expected_vec = UOp(Ops.VECTORIZE, dtypes.index.vec(2), (uconst(0), r0))
self.assertEqual(idx, (r0 < 3).where(expected_vec, UOp.invalid()))
class TestBounds(unittest.TestCase):
def test_unrolled_arange(self):
# #include <metal_stdlib>