shared_codegen_spec and fix index spec (#12967)

* split shared_codegen_spec and fix index

* add VCONST to program_spec and move index to shared_codegen_spec

* working ignore_oob=0

* cleanup

* fix spec

* undo that

* move barrier and special earlier

* fix more spec issues

* more updates

* remove special from program_spec

* cleanup and fixes

* move more to shared

* special is not in shared_spec

* some comments

* don't do bounds check there
Sieds Lykles
2025-10-29 09:14:11 +01:00
committed by GitHub
parent 1c362736aa
commit 9f39f6391c
7 changed files with 101 additions and 85 deletions

View File

@@ -1,5 +1,5 @@
import unittest, itertools, math
-from tinygrad import Tensor, Device, dtypes
+from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.dtype import DType, ConstType
from tinygrad.uop.ops import Ops, UOp
from tinygrad.codegen import full_rewrite_to_sink
@@ -126,7 +126,8 @@ class TestBitcastConstFolding(unittest.TestCase):
t({dtypes.int64: 4598983288165178391, dtypes.uint64: 4598983288165178391, dtypes.float64: 0.29485681936461233})
def test_vec_bitcast(self):
-r = full_rewrite_to_sink(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
+with Context(SPEC=0):
+  r = full_rewrite_to_sink(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
self.assertEqual(r.op, Ops.VECTORIZE)
self.assertEqual(r.dtype, dtypes.uint32.vec(3))
self.assertEqual(tuple(x.arg for x in r.src), (2**32-1, 2**31, 75))
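
The `Context` import added above is what these tests use to toggle checks: `Context` temporarily overrides tinygrad ContextVars (here `SPEC`, elsewhere `IGNORE_OOB`) for the duration of a `with` block. A minimal sketch of the mechanism, using the `IGNORE_OOB` variable that appears throughout this PR (assuming the usual `tinygrad.helpers` import):

    from tinygrad import Context
    from tinygrad.helpers import IGNORE_OOB

    # ContextVars expose their current setting via .value; Context(...) overrides
    # them by keyword inside the with block and restores them on exit
    before = IGNORE_OOB.value
    with Context(IGNORE_OOB=0):   # force bounds checking on
      assert IGNORE_OOB.value == 0
    assert IGNORE_OOB.value == before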

View File

@@ -46,7 +46,7 @@ class TestRendererFailures(unittest.TestCase):
def test_gated_store_with_alu(self):
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
-gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, gate_alu), UOp.const(dtypes.int, 1)))
+gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1)))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
@@ -57,7 +57,7 @@ class TestRendererFailures(unittest.TestCase):
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
-gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0+lidx1*4, gate_alu_0&gate_alu_1), UOp.const(dtypes.int, 1)))
+gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1)))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
ret = _test_uop_result([], uops, local_size=[4, 2, 1])[0]
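
The recurring change in these tests is the gating API: the gate is no longer a separate boolean source of INDEX, it now rides along with the index expression via `.valid(gate)`. A minimal before/after sketch assembled from the same UOps as the test above:

    from tinygrad import dtypes
    from tinygrad.uop.ops import Ops, UOp

    a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
    lidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')
    gate = lidx0.ne(0)

    # old style passed the gate as an extra INDEX source: a.index(lidx0, gate)
    # new style attaches it to the index expression itself:
    store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate)), UOp.const(dtypes.int, 1)))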

View File

@@ -307,9 +307,10 @@ class TestUOpGraph(unittest.TestCase):
for vec_size in [2, 4, 8]:
consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
vec = UOp(Ops.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
-uops = to_uops_list([UOp(Ops.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
-for uop, const in zip(uops, consts):
-  self.assertEqual(uop, const)
+with Context(SPEC=0):
+  uops = to_uops_list([UOp(Ops.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
+  for uop, const in zip(uops, consts):
+    self.assertEqual(uop, const)
@unittest.skip("no longer testable standalone")
def test_wmma_vectorize_fold(self):
@@ -505,10 +506,10 @@ class TestUOpGraph(unittest.TestCase):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0)
v = Variable("v", 0, 20)
-st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v, v<16), UOp.const(dtypes.int, 0)))
+st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v.valid(v<16)), UOp.const(dtypes.int, 0)))
to_uops_list([st0])
-st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v), v, v<20))
+st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v.valid(v<20)), v))
with self.assertRaises(RuntimeError): to_uops_list([st1])
@unittest.skip("if not allowed in graph")
@@ -541,7 +542,7 @@ class TestUOpGraph(unittest.TestCase):
ridx = UOp.range(20, 0)
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16))),))
to_uops_list([ld0])
glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),))
@@ -552,7 +553,7 @@ class TestUOpGraph(unittest.TestCase):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
ridx = UOp.range(20, 0)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx, ridx.cast(dtypes.bool).logical_not()),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not())),))
to_uops_list([ld0])
@unittest.skip("Bool load is not supported yet")
@@ -574,23 +575,23 @@ class TestUOpGraph(unittest.TestCase):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5<gidx0)&(gidx0<16)),))
-ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<16),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16))),))
+ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16)),))
to_uops_list([ld0, ld1])
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<17),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17)),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_symbolic_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = Variable("i", 1, 80)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<10),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10)),))
to_uops_list([ld0])
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<15),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15)),))
to_uops_list([ld0])
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<20),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20)),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_index_load(self):
@@ -598,11 +599,11 @@ class TestUOpGraph(unittest.TestCase):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),)).cast(dtypes.index)
-ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8)),)).cast(dtypes.index)
+ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32))),))
to_uops_list([ld1])
-ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<64)),))
+ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64))),))
with self.assertRaises(RuntimeError): to_uops_list([ld1])
def test_bounds_with_loaded_bool(self):
@@ -620,7 +621,7 @@ class TestUOpGraph(unittest.TestCase):
glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
idx = UOp.const(dtypes.int, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),))
-ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx, UOp.const(dtypes.bool, True)),))
+ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx.valid(UOp.const(dtypes.bool, True))),))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
ld0 = uops[-1].src[-1]
# the gate and invalid value are deleted from ld1
@@ -633,7 +634,7 @@ class TestUOpGraph(unittest.TestCase):
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),))
-ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(lidx+2, UOp.const(dtypes.bool, True)),))
+ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True))),))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
ld0 = uops[-1].src[-1]
@@ -646,7 +647,7 @@ class TestUOpGraph(unittest.TestCase):
idx1 = UOp.const(dtypes.int, 0)
val = UOp.const(dtypes.int, 42)
st0 = glbl.index(UOp.invalid()).store(val)
-st1 = glbl.index(idx0, UOp.const(dtypes.bool, True)).store(val)
+st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True))).store(val)
uops = to_uops_list([st0, st1])
# only the second store happens
self.assertEqual(len(uops), 5)
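
All of the tests above exercise the same contract: with `IGNORE_OOB=0`, lowering a load or store whose gated index can still fall outside the buffer raises `RuntimeError`. A condensed sketch of the pattern, using `to_uops_list`, the helper these tests already import:

    from tinygrad import dtypes, Context
    from tinygrad.uop.ops import Ops, UOp

    with Context(IGNORE_OOB=0):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)  # buffer of 16 ints
      ridx = UOp.range(20, 0)                                   # index runs 0..19
      ok = UOp(Ops.LOAD, dtypes.int, (glbl.index(ridx.valid(ridx<16)),))   # gate proves in bounds
      bad = UOp(Ops.LOAD, dtypes.int, (glbl.index(ridx.valid(ridx<20)),))  # 16..19 still reachable
      to_uops_list([ok])    # validates
      # to_uops_list([bad]) # raises RuntimeError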

View File

@@ -277,7 +277,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
gate = gidx0<UOp.const(dtypes.int, 1)
-idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, gidx0 * UOp.const(dtypes.int, 2), gate))
+idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, (gidx0 * UOp.const(dtypes.int, 2)).valid(gate)))
val = UOp.const(dtypes.float, 42.0)
store = UOp(Ops.STORE, dtypes.void, (idx, val))
uops = to_uops_list([store])
@@ -294,7 +294,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
idx = gidx0 * UOp.const(dtypes.int, 2)
-idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx, gidx0<UOp.const(dtypes.int, 1)))
+idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gidx0<UOp.const(dtypes.int, 1))))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
val = UOp.const(dtypes.float, 42.0)
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]

View File

@@ -1,5 +1,5 @@
import unittest, functools
-from tinygrad import Tensor
+from tinygrad import Tensor, Context
import numpy as np
def orthogonality_helper(A:Tensor, tolerance=1e-5):
@@ -27,15 +27,16 @@ class TestLinAlg(unittest.TestCase):
reconstruction_helper([U,s_diag,V],a)
def _test_svd_nonfull(self, size):
-a = Tensor.randn(size).realize()
-U,S,V = a.svd(full_matrices=False)
-b_shape,m,n = size[0:-2],size[-2],size[-1]
-k = min(m,n)
-s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)).expand(b_shape + (k,k)))
-#reduced U,V is only orthogonal along smaller dim
-if (m < n): orthogonality_helper(U),orthogonality_helper(V)
-else: orthogonality_helper(U.transpose(-2,-1)),orthogonality_helper(V.transpose(-2,-1))
-reconstruction_helper([U,s_diag,V],a)
+with Context(IGNORE_OOB=1): # sometimes this is slow in CI
+  a = Tensor.randn(size).realize()
+  U,S,V = a.svd(full_matrices=False)
+  b_shape,m,n = size[0:-2],size[-2],size[-1]
+  k = min(m,n)
+  s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)).expand(b_shape + (k,k)))
+  #reduced U,V is only orthogonal along smaller dim
+  if (m < n): orthogonality_helper(U),orthogonality_helper(V)
+  else: orthogonality_helper(U.transpose(-2,-1)),orthogonality_helper(V.transpose(-2,-1))
+  reconstruction_helper([U,s_diag,V],a)
# faster for parallel pytest
def test_svd_nonfull_2_2(self): self._test_svd_nonfull((2,2))
@@ -75,4 +76,4 @@ class TestLinAlg(unittest.TestCase):
orthogonality_helper(b if size[-1] > size[-2] else b.transpose(-2, -1), tolerance=1e-3)
if __name__ == "__main__":
-unittest.main()
\ No newline at end of file
+unittest.main()

View File

@@ -39,6 +39,7 @@ shared_spec = PatternMatcher([
(UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
+(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
])
# ***** UOp spec in the Tensor graph *****
@@ -105,9 +106,9 @@ tensor_spec = PatternMatcher([
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
])+shared_spec
-# ***** UOp spec in linearized programs *****
+# ***** UOp spec in codegen shared between kernel and program *****
-program_spec = PatternMatcher([
+shared_codegen_spec = PatternMatcher([
# DEFINEs
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL),
@@ -117,42 +118,59 @@ program_spec = PatternMatcher([
(UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True),
(UPat(Ops.GROUP, dtypes.void), lambda: True),
-# INDEX is used in new style load/store
-(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
-(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True),
-# LOAD (idx, alt_value) / STORE(if gated) / LOAD(idx) / STORE(idx, val)
-(UPat().index(UPat(), UPat(dtype=dtypes.bool, name="gate"), name="idx").or_casted().load(UPat()), validate_index),
-(UPat().index(UPat(), UPat(dtype=dtypes.bool, name="gate"), name="idx").or_casted().store(UPat()), validate_index),
-(UPat().index(UPat(), name="idx").or_casted().load(), validate_index),
-(UPat().index(UPat(), name="idx").or_casted().store(UPat()), validate_index),
-# RANGE/SPECIAL define loops, END closes them
-(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
-(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
-# make sure all index dtypes have been lowered
-(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
-(UPat(Ops.CONST, arg=Invalid), lambda: False),
-(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.src)),
# WMMA has a <a, b, acc>
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
-# if has a <gate, index_for_dedup>
-(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),
-(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
# UNROLL/CONTRACT is used here for WMMA
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# VECTORIZE/GEP
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
-# BARRIER
-(UPat(Ops.BARRIER, dtypes.void, src=(UPat(),)), lambda: True),
+# LOAD(idx) / STORE(idx, val) / LOAD with alt value only exists in program_spec
+(UPat().index(UPat()).or_casted().load(), lambda: True),
+(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
# all CUSTOM + PRECAST
(UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
-])+shared_spec
+# INDEX
+(UPat(GroupOp.Defines, name="buf").or_after().index(UPat.var("idx")), validate_index),
+# SPECIAL
+(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.index, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
+# BARRIER
+(UPat(Ops.BARRIER, dtypes.void, src=(UPat(),)), lambda: True),
+])
+# ***** UOp spec in linearized programs *****
+program_spec = PatternMatcher([
+# INDEX with a gate as third src
+(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines, name="buf").or_after(), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
+# LOAD (idx, alt_value), LOAD can have an alt value, but only if the index has a gate
+(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(UPat()), lambda: True),
+# END closes ranges
+(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
+# make sure all index dtypes have been lowered
+(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
+(UPat(Ops.CONST, arg=Invalid), lambda: False),
+(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.arg) and len(x.arg)==x.dtype.vcount>1 and
+  type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
+# if has a <gate, index_for_dedup>
+(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),
+(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
+])+shared_codegen_spec+shared_spec
# ***** UOp spec in kernel graph *****
@@ -160,14 +178,6 @@ kernel_spec = PatternMatcher([
# index is allowed here
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
-# LOAD(idx) / STORE(idx, val) -- NOTE: we do this here to not run validate_index since z3 doesn't support Invalid
-(UPat(Ops.INDEX).or_casted().load(), lambda: True),
-(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
-# UNROLL/CONTRACT is used here for WMMA
-(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
-(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# END can end multiple axes here
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True),
@@ -176,10 +186,7 @@ kernel_spec = PatternMatcher([
# reduce must be on ranges
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
-# intermediate index
-(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
-])+program_spec+shared_spec
+])+shared_codegen_spec+shared_spec
# *** this spec should match all UOps ever created ***
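
The net effect of this file: rules shared by kernel and program validation now live in `shared_codegen_spec`, and the concrete specs are assembled by `PatternMatcher` concatenation, e.g. `program_spec = PatternMatcher([...]) + shared_codegen_spec + shared_spec`. A spec rule returns True to accept, False to reject, or None to fall through to later patterns (see the `... or None` rule added to `shared_spec` above). A toy illustration of that layering; the two rules here are invented for the example:

    from tinygrad.uop.ops import PatternMatcher, UPat, Ops

    # base layer: every CONST must carry an arg
    base = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x.arg is not None)])
    # a stricter layer adds its own rule, then inherits base via +
    strict = PatternMatcher([
      (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src) > 1),
    ]) + base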

View File

@@ -1,6 +1,6 @@
from typing import Callable
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, python_alu, graph_rewrite
-from tinygrad.dtype import ImageDType, dtypes
+from tinygrad.dtype import ImageDType, dtypes, Invalid
from tinygrad.helpers import IGNORE_OOB, Context, cpu_profile
try:
@@ -25,15 +25,19 @@ try:
# ctx is (solver, load_number_dict)
# each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different
# contexts can have the same hash but error on comparison
+def add_valid(ctx, cond, x):
+  ctx[0].add(cond.arg[1])
+  return x
z3_renderer = PatternMatcher([
(UPat(Ops.NOOP, name="cond").where(UPat(Ops.NOOP, name="x"), UPat(Ops.CONST, arg=Invalid)), add_valid),
(UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))),
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))),
# loaded bools become a z3 int with min max of 0-1
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))).cast(x.dtype)),
-(UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"),
-  lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))),
+(UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"), lambda x,ctx:
+  UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))) if x.arg is not Invalid else None),
# z3 can cast from bool to int automatically
(UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))),
@@ -55,24 +59,26 @@ try:
z3_imported = True
except (ImportError, AttributeError): z3_imported = False
-def validate_index(idx:UOp, gate:UOp|None=None):
+def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
+  if idx.op is Ops.CONST and idx.arg is Invalid: return True
if gate is None: gate = UOp.const(dtypes.bool, True)
# TODO: check for overflow
-if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True
+if IGNORE_OOB or isinstance(buf.dtype, ImageDType) or (sz := buf.ptrdtype.size) == -1: return True
# We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
if 0<=idx.vmin and idx.vmax<sz: return True
# WEBGPU has a BITCAST in the index. TODO: fix
if any(x.op is Ops.BITCAST for x in idx.toposort()): return True
if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=1 or \"pip install 'z3-solver>=4.12.4'\"")
solver = z3.Solver(ctx=z3.Context())
-z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], gate)
+z3_idx, z3_mask = uops_to_z3(solver, idx, gate)
solver.add(z3_mask)
with cpu_profile("validate index with z3", "TINY"):
-if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
-  print(f"idx={idx.src[1].render(simplify=False)}")
-  print(f"gate={gate.render(simplify=False)}")
-  print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
-  return False
-return True
+match solver.check((z3_idx<0)|(sz<=z3_idx)):
+  case z3.unsat: return True
+  case z3.sat: print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
+  case z3.unknown: print(f"# UNKNOWN RESULT FROM Z3: {solver.reason_unknown()}\nconstraints = {solver}")
+print(f"idx={idx.render(simplify=False)}")
+print(f"mask={gate.render(simplify=False)}")
+return False
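
`validate_index` now takes the buffer and index as separate UOps (matched out of INDEX by the spec patterns above), tries the cheap `vmin`/`vmax` interval check first, and only falls back to z3 when that is inconclusive. The z3 query asks: given the index's range and its gate, is an index outside `[0, sz)` reachable? A standalone sketch of that query in plain z3 (assumes the z3-solver package), mirroring the earlier test with a 16-int buffer, an index in `[0, 20)` and gate `i < 16`:

    import z3

    solver = z3.Solver()
    i = z3.Int("i")
    solver.add(0 <= i, i < 20)  # range of the loop index
    solver.add(i < 16)          # the gate attached via idx.valid(...)
    sz = 16                     # buffer size
    # out-of-bounds being unsat means the gated access is provably safe
    assert solver.check(z3.Or(i < 0, sz <= i)) == z3.unsat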