mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
add cifar gated uop_given_valid regression test (#13536)
This commit is contained in:
@@ -6,7 +6,7 @@ from tinygrad.dtype import dtypes, ConstType, DType, Invalid
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.helpers import Context
|
||||
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer
|
||||
from tinygrad.uop.symbolic import sym, commutative
|
||||
from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid
|
||||
from tinygrad.uop.validate import uops_to_z3
|
||||
|
||||
def check_uop_against_string(self, v:UOp, s:str):
|
||||
@@ -1029,6 +1029,25 @@ class TestSymbolicRealWorld(unittest.TestCase):
|
||||
self.assertIn(idx.render(),
|
||||
("(lidx3+((lidx5+1)//16*802816+(lidx5+1)%16*49+gidx0*3211264+gidx1*784+gidx2*8+lidx4*100352)+2207744)",))
|
||||
|
||||
class TestGatedUopGivenValid(unittest.TestCase):
|
||||
def test_invalid_gate_simplifies_index(self):
|
||||
r0 = Variable("r0", 0, 2)
|
||||
|
||||
idx:UOp = (r0 < 3).where((r0 + uconst(-1)) // uconst(3), UOp.invalid())
|
||||
idx = graph_rewrite(idx, pm_simplify_valid)
|
||||
self.assertEqual(idx, (r0 < 3).where(uconst(0), UOp.invalid()))
|
||||
|
||||
def test_invalid_gate_simplifies_vectorize(self):
|
||||
r0 = Variable("r0", 0, 2)
|
||||
|
||||
idx0 = (r0 + uconst(-1)) // uconst(3)
|
||||
idx1 = r0 % uconst(3)
|
||||
idx:UOp = (r0 < 3).where(UOp(Ops.VECTORIZE, dtypes.index.vec(2), (idx0, idx1)), UOp.invalid())
|
||||
idx = graph_rewrite(idx, pm_simplify_valid)
|
||||
# NOTE: independent simplification: (r0-1)//3 -> 0, r0%3 -> r0 when r0 in [0,2]
|
||||
expected_vec = UOp(Ops.VECTORIZE, dtypes.index.vec(2), (uconst(0), r0))
|
||||
self.assertEqual(idx, (r0 < 3).where(expected_vec, UOp.invalid()))
|
||||
|
||||
class TestBounds(unittest.TestCase):
|
||||
def test_unrolled_arange(self):
|
||||
# #include <metal_stdlib>
|
||||
|
||||
Reference in New Issue
Block a user