From 9f39f6391cba92f3ca7ec14e3e9703d57fad3fb5 Mon Sep 17 00:00:00 2001 From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com> Date: Wed, 29 Oct 2025 09:14:11 +0100 Subject: [PATCH] shared_codegen_spec and fix index spec (#12967) * split shared_codegen_spec and fix index * add VCONST to program_spec and move index to shared_codegen_spec * working ignore_oob=0 * cleanup * fix spec * undo that * move barrier and special earlier * fix more spec issues * more updates * remove special from program_spec * cleanup and fixes * move more to shared * special is not in shared_spec * some comments * dont do bounds check there --- test/test_const_folding.py | 5 ++- test/test_renderer_failures.py | 4 +- test/test_uop_graph.py | 39 +++++++++-------- test/test_uops.py | 4 +- test/unit/test_linalg.py | 23 +++++----- tinygrad/uop/spec.py | 79 ++++++++++++++++++---------------- tinygrad/uop/validate.py | 32 ++++++++------ 7 files changed, 101 insertions(+), 85 deletions(-) diff --git a/test/test_const_folding.py b/test/test_const_folding.py index 184bbf274a..3a709b1908 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -1,5 +1,5 @@ import unittest, itertools, math -from tinygrad import Tensor, Device, dtypes +from tinygrad import Tensor, Device, dtypes, Context from tinygrad.dtype import DType, ConstType from tinygrad.uop.ops import Ops, UOp from tinygrad.codegen import full_rewrite_to_sink @@ -126,7 +126,8 @@ class TestBitcastConstFolding(unittest.TestCase): t({dtypes.int64: 4598983288165178391, dtypes.uint64: 4598983288165178391, dtypes.float64: 0.29485681936461233}) def test_vec_bitcast(self): - r = full_rewrite_to_sink(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0] + with Context(SPEC=0): + r = full_rewrite_to_sink(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0] self.assertEqual(r.op, Ops.VECTORIZE) self.assertEqual(r.dtype, dtypes.uint32.vec(3)) self.assertEqual(tuple(x.arg for x in r.src), (2**32-1, 2**31, 75)) diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py index 9e0559d44f..8efc7006dd 100644 --- a/test/test_renderer_failures.py +++ b/test/test_renderer_failures.py @@ -46,7 +46,7 @@ class TestRendererFailures(unittest.TestCase): def test_gated_store_with_alu(self): a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0) - gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, gate_alu), UOp.const(dtypes.int, 1))) + gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1))) sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,)) uops = full_rewrite(sink, Device[Device.DEFAULT].renderer) ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0] @@ -57,7 +57,7 @@ class TestRendererFailures(unittest.TestCase): a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0) gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0) - gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0+lidx1*4, gate_alu_0&gate_alu_1), UOp.const(dtypes.int, 1))) + gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1))) sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,)) uops = full_rewrite(sink, Device[Device.DEFAULT].renderer) ret = _test_uop_result([], uops, local_size=[4, 2, 1])[0] diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 38ed9ae483..704f17c40e 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -307,9 +307,10 @@ class TestUOpGraph(unittest.TestCase): for vec_size in [2, 4, 8]: consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)] vec = UOp(Ops.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts)) - uops = to_uops_list([UOp(Ops.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)]) - for uop, const in zip(uops, consts): - self.assertEqual(uop, const) + with Context(SPEC=0): + uops = to_uops_list([UOp(Ops.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)]) + for uop, const in zip(uops, consts): + self.assertEqual(uop, const) @unittest.skip("no longer testable standalone") def test_wmma_vectorize_fold(self): @@ -505,10 +506,10 @@ class TestUOpGraph(unittest.TestCase): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0) v = Variable("v", 0, 20) - st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v, v<16), UOp.const(dtypes.int, 0))) + st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v.valid(v<16)), UOp.const(dtypes.int, 0))) to_uops_list([st0]) - st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v), v, v<20)) + st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v.valid(v<20)), v)) with self.assertRaises(RuntimeError): to_uops_list([st1]) @unittest.skip("if not allowed in graph") @@ -541,7 +542,7 @@ class TestUOpGraph(unittest.TestCase): ridx = UOp.range(20, 0) glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16))),)) to_uops_list([ld0]) glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0) ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),)) @@ -552,7 +553,7 @@ class TestUOpGraph(unittest.TestCase): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0) ridx = UOp.range(20, 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx, ridx.cast(dtypes.bool).logical_not()),)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not())),)) to_uops_list([ld0]) @unittest.skip("Bool load is not supported yet") @@ -574,23 +575,23 @@ class TestUOpGraph(unittest.TestCase): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) gidx0 = UOp.range(42, 0, AxisType.GLOBAL) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5=0)&(ld0<32)),)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8)),)).cast(dtypes.index) + ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32))),)) to_uops_list([ld1]) - ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<64)),)) + ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64))),)) with self.assertRaises(RuntimeError): to_uops_list([ld1]) def test_bounds_with_loaded_bool(self): @@ -620,7 +621,7 @@ class TestUOpGraph(unittest.TestCase): glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2) idx = UOp.const(dtypes.int, 0) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),)) - ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx, UOp.const(dtypes.bool, True)),)) + ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx.valid(UOp.const(dtypes.bool, True))),)) uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))]) ld0 = uops[-1].src[-1] # the gate and invalid value are deleted from ld1 @@ -633,7 +634,7 @@ class TestUOpGraph(unittest.TestCase): st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int))) barrier = UOp(Ops.BARRIER, dtypes.void, (st, )) ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),)) - ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(lidx+2, UOp.const(dtypes.bool, True)),)) + ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True))),)) uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))]) ld0 = uops[-1].src[-1] @@ -646,7 +647,7 @@ class TestUOpGraph(unittest.TestCase): idx1 = UOp.const(dtypes.int, 0) val = UOp.const(dtypes.int, 42) st0 = glbl.index(UOp.invalid()).store(val) - st1 = glbl.index(idx0, UOp.const(dtypes.bool, True)).store(val) + st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True))).store(val) uops = to_uops_list([st0, st1]) # only the second store happens self.assertEqual(len(uops), 5) diff --git a/test/test_uops.py b/test/test_uops.py index 5bf49bd9a4..375c166124 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -277,7 +277,7 @@ class TestGatedStoreRewrite(unittest.TestCase): gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0') gate = gidx0 size[-2] else b.transpose(-2, -1), tolerance=1e-3) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index e97db78ef5..72a81f2019 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -39,6 +39,7 @@ shared_spec = PatternMatcher([ (UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \ all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)), + (UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None), ]) # ***** UOp spec in the Tensor graph ***** @@ -105,9 +106,9 @@ tensor_spec = PatternMatcher([ (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True), ])+shared_spec -# ***** UOp spec in linearized programs ***** +# ***** UOp spec in codegen shared between kernel and program ***** -program_spec = PatternMatcher([ +shared_codegen_spec = PatternMatcher([ # DEFINEs (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL), (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL), @@ -117,42 +118,59 @@ program_spec = PatternMatcher([ (UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True), (UPat(Ops.GROUP, dtypes.void), lambda: True), - # INDEX is used in new style load/store - (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True), - (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True), - - # LOAD (idx, alt_value) / STORE(if gated) / LOAD(idx) / STORE(idx, val) - (UPat().index(UPat(), UPat(dtype=dtypes.bool, name="gate"), name="idx").or_casted().load(UPat()), validate_index), - (UPat().index(UPat(), UPat(dtype=dtypes.bool, name="gate"), name="idx").or_casted().store(UPat()), validate_index), - (UPat().index(UPat(), name="idx").or_casted().load(), validate_index), - (UPat().index(UPat(), name="idx").or_casted().store(UPat()), validate_index), - # RANGE/SPECIAL define loops, END closes them - (UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)), (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True), - # make sure all index dtypes have been lowered - (UPat(GroupOp.All, dtype=dtypes.index), lambda: False), - (UPat(Ops.CONST, arg=Invalid), lambda: False), - (UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.src)), - # WMMA has a (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8), - # if has a - (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True), - (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True), + # UNROLL/CONTRACT is used here for WMMA + (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), + (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)), # VECTORIZE/GEP (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)), (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()), - # BARRIER - (UPat(Ops.BARRIER, dtypes.void, src=(UPat(),)), lambda: True), + # LOAD(idx) / STORE(idx, val) / LOAD with alt value only exists in program_spec + (UPat().index(UPat()).or_casted().load(), lambda: True), + (UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True), # all CUSTOM + PRECAST (UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True), -])+shared_spec + + # INDEX + (UPat(GroupOp.Defines, name="buf").or_after().index(UPat.var("idx")), validate_index), + + # SPECIAL + (UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.index, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)), + + # BARRIER + (UPat(Ops.BARRIER, dtypes.void, src=(UPat(),)), lambda: True), +]) + +# ***** UOp spec in linearized programs ***** + +program_spec = PatternMatcher([ + # INDEX with a gate as third src + (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines, name="buf").or_after(), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index), + + # LOAD (idx, alt_value), LOAD can have an alt value, but only if the index has a gate + (UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(UPat()), lambda: True), + + # END closes ranges + (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True), + + # make sure all index dtypes have been lowered + (UPat(GroupOp.All, dtype=dtypes.index), lambda: False), + (UPat(Ops.CONST, arg=Invalid), lambda: False), + (UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.arg) and len(x.arg)==x.dtype.vcount>1 and + type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))), + + # if has a + (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True), + (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True), +])+shared_codegen_spec+shared_spec # ***** UOp spec in kernel graph ***** @@ -160,14 +178,6 @@ kernel_spec = PatternMatcher([ # index is allowed here (UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True), - # LOAD(idx) / STORE(idx, val) -- NOTE: we do this here to not run validate_index since z3 doesn't support Invalid - (UPat(Ops.INDEX).or_casted().load(), lambda: True), - (UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True), - - # UNROLL/CONTRACT is used here for WMMA - (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), - (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)), - # END can end multiple axes here (UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True), @@ -176,10 +186,7 @@ kernel_spec = PatternMatcher([ # reduce must be on ranges (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])), - - # intermediate index - (UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None), -])+program_spec+shared_spec +])+shared_codegen_spec+shared_spec # *** this spec should match all UOps ever created *** diff --git a/tinygrad/uop/validate.py b/tinygrad/uop/validate.py index 0e134c26ee..f379b20922 100644 --- a/tinygrad/uop/validate.py +++ b/tinygrad/uop/validate.py @@ -1,6 +1,6 @@ from typing import Callable from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, python_alu, graph_rewrite -from tinygrad.dtype import ImageDType, dtypes +from tinygrad.dtype import ImageDType, dtypes, Invalid from tinygrad.helpers import IGNORE_OOB, Context, cpu_profile try: @@ -25,15 +25,19 @@ try: # ctx is (solver, load_number_dict) # each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different # contexts can have the same hash but error on comparison + def add_valid(ctx, cond, x): + ctx[0].add(cond.arg[1]) + return x z3_renderer = PatternMatcher([ + (UPat(Ops.NOOP, name="cond").where(UPat(Ops.NOOP, name="x"), UPat(Ops.CONST, arg=Invalid)), add_valid), (UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))), (UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))), (UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))), # loaded bools become a z3 int with min max of 0-1 (UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))).cast(x.dtype)), - (UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"), - lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))), + (UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"), lambda x,ctx: + UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))) if x.arg is not Invalid else None), # z3 can cast from bool to int automatically (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]), (UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))), @@ -55,24 +59,26 @@ try: z3_imported = True except (ImportError, AttributeError): z3_imported = False -def validate_index(idx:UOp, gate:UOp|None=None): +def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None): + if idx.op is Ops.CONST and idx.arg is Invalid: return True if gate is None: gate = UOp.const(dtypes.bool, True) # TODO: check for overflow - if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True + if IGNORE_OOB or isinstance(buf.dtype, ImageDType) or (sz := buf.ptrdtype.size) == -1: return True # We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask - if 0<=idx.src[1].vmin and idx.src[1].vmax= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"") solver = z3.Solver(ctx=z3.Context()) - z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], gate) + z3_idx, z3_mask = uops_to_z3(solver, idx, gate) solver.add(z3_mask) with cpu_profile("validate index with z3", "TINY"): - if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat: - print(f"idx={idx.src[1].render(simplify=False)}") - print(f"gate={gate.render(simplify=False)}") - print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}") - return False - return True + match solver.check((z3_idx<0)|(sz<=z3_idx)): + case z3.unsat: return True + case z3.sat: print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}") + case z3.unknown: print(f"# UNKNOWN RESULT FROM Z3: {solver.reason_unknown()}\nconstraints = {solver}") + print(f"idx={idx.render(simplify=False)}") + print(f"mask={gate.render(simplify=False)}") + return False