shared_codegen_spec and fix index spec (#12967)

* split shared_codegen_spec and fix index

* add VCONST to program_spec and move index to shared_codegen_spec

* working ignore_oob=0

* cleanup

* fix spec

* undo that

* move barrier and special earlier

* fix more spec issues

* more updates

* remove special from program_spec

* cleanup and fixes

* move more to shared

* special is not in shared_spec

* some comments

* dont do bounds check there
This commit is contained in:
Sieds Lykles
2025-10-29 09:14:11 +01:00
committed by GitHub
parent 1c362736aa
commit 9f39f6391c
7 changed files with 101 additions and 85 deletions

View File

@@ -1,5 +1,5 @@
import unittest, itertools, math import unittest, itertools, math
from tinygrad import Tensor, Device, dtypes from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.dtype import DType, ConstType from tinygrad.dtype import DType, ConstType
from tinygrad.uop.ops import Ops, UOp from tinygrad.uop.ops import Ops, UOp
from tinygrad.codegen import full_rewrite_to_sink from tinygrad.codegen import full_rewrite_to_sink
@@ -126,7 +126,8 @@ class TestBitcastConstFolding(unittest.TestCase):
t({dtypes.int64: 4598983288165178391, dtypes.uint64: 4598983288165178391, dtypes.float64: 0.29485681936461233}) t({dtypes.int64: 4598983288165178391, dtypes.uint64: 4598983288165178391, dtypes.float64: 0.29485681936461233})
def test_vec_bitcast(self): def test_vec_bitcast(self):
r = full_rewrite_to_sink(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0] with Context(SPEC=0):
r = full_rewrite_to_sink(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
self.assertEqual(r.op, Ops.VECTORIZE) self.assertEqual(r.op, Ops.VECTORIZE)
self.assertEqual(r.dtype, dtypes.uint32.vec(3)) self.assertEqual(r.dtype, dtypes.uint32.vec(3))
self.assertEqual(tuple(x.arg for x in r.src), (2**32-1, 2**31, 75)) self.assertEqual(tuple(x.arg for x in r.src), (2**32-1, 2**31, 75))

View File

@@ -46,7 +46,7 @@ class TestRendererFailures(unittest.TestCase):
def test_gated_store_with_alu(self): def test_gated_store_with_alu(self):
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0) gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, gate_alu), UOp.const(dtypes.int, 1))) gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1)))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,)) sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer) uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0] ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
@@ -57,7 +57,7 @@ class TestRendererFailures(unittest.TestCase):
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0) gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0) gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0+lidx1*4, gate_alu_0&gate_alu_1), UOp.const(dtypes.int, 1))) gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1)))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,)) sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer) uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
ret = _test_uop_result([], uops, local_size=[4, 2, 1])[0] ret = _test_uop_result([], uops, local_size=[4, 2, 1])[0]

View File

@@ -307,9 +307,10 @@ class TestUOpGraph(unittest.TestCase):
for vec_size in [2, 4, 8]: for vec_size in [2, 4, 8]:
consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)] consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
vec = UOp(Ops.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts)) vec = UOp(Ops.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
uops = to_uops_list([UOp(Ops.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)]) with Context(SPEC=0):
for uop, const in zip(uops, consts): uops = to_uops_list([UOp(Ops.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
self.assertEqual(uop, const) for uop, const in zip(uops, consts):
self.assertEqual(uop, const)
@unittest.skip("no longer testable standalone") @unittest.skip("no longer testable standalone")
def test_wmma_vectorize_fold(self): def test_wmma_vectorize_fold(self):
@@ -505,10 +506,10 @@ class TestUOpGraph(unittest.TestCase):
with Context(IGNORE_OOB=0): with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0) glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0)
v = Variable("v", 0, 20) v = Variable("v", 0, 20)
st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v, v<16), UOp.const(dtypes.int, 0))) st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v.valid(v<16)), UOp.const(dtypes.int, 0)))
to_uops_list([st0]) to_uops_list([st0])
st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v), v, v<20)) st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v.valid(v<20)), v))
with self.assertRaises(RuntimeError): to_uops_list([st1]) with self.assertRaises(RuntimeError): to_uops_list([st1])
@unittest.skip("if not allowed in graph") @unittest.skip("if not allowed in graph")
@@ -541,7 +542,7 @@ class TestUOpGraph(unittest.TestCase):
ridx = UOp.range(20, 0) ridx = UOp.range(20, 0)
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int) i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),)) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16))),))
to_uops_list([ld0]) to_uops_list([ld0])
glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0) glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),)) ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),))
@@ -552,7 +553,7 @@ class TestUOpGraph(unittest.TestCase):
with Context(IGNORE_OOB=0): with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0) glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
ridx = UOp.range(20, 0) ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx, ridx.cast(dtypes.bool).logical_not()),)) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not())),))
to_uops_list([ld0]) to_uops_list([ld0])
@unittest.skip("Bool load is not supported yet") @unittest.skip("Bool load is not supported yet")
@@ -574,23 +575,23 @@ class TestUOpGraph(unittest.TestCase):
with Context(IGNORE_OOB=0): with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
gidx0 = UOp.range(42, 0, AxisType.GLOBAL) gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5<gidx0)&(gidx0<16)),)) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16))),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<16),)) ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16)),))
to_uops_list([ld0, ld1]) to_uops_list([ld0, ld1])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<17),)) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17)),))
with self.assertRaises(RuntimeError): to_uops_list([ld0]) with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_symbolic_mask(self): def test_in_out_of_bounds_access_symbolic_mask(self):
with Context(IGNORE_OOB=0): with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = Variable("i", 1, 80) i = Variable("i", 1, 80)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<10),)) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10)),))
to_uops_list([ld0]) to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<15),)) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15)),))
to_uops_list([ld0]) to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<20),)) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20)),))
with self.assertRaises(RuntimeError): to_uops_list([ld0]) with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_index_load(self): def test_in_out_of_bounds_access_index_load(self):
@@ -598,11 +599,11 @@ class TestUOpGraph(unittest.TestCase):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0) glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
gidx0 = UOp.range(42, 0, AxisType.GLOBAL) gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),)).cast(dtypes.index) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8)),)).cast(dtypes.index)
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),)) ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32))),))
to_uops_list([ld1]) to_uops_list([ld1])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<64)),)) ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64))),))
with self.assertRaises(RuntimeError): to_uops_list([ld1]) with self.assertRaises(RuntimeError): to_uops_list([ld1])
def test_bounds_with_loaded_bool(self): def test_bounds_with_loaded_bool(self):
@@ -620,7 +621,7 @@ class TestUOpGraph(unittest.TestCase):
glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2) glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
idx = UOp.const(dtypes.int, 0) idx = UOp.const(dtypes.int, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),)) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx, UOp.const(dtypes.bool, True)),)) ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx.valid(UOp.const(dtypes.bool, True))),))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))]) uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
ld0 = uops[-1].src[-1] ld0 = uops[-1].src[-1]
# the gate and invalid value are deleted from ld1 # the gate and invalid value are deleted from ld1
@@ -633,7 +634,7 @@ class TestUOpGraph(unittest.TestCase):
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int))) st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
barrier = UOp(Ops.BARRIER, dtypes.void, (st, )) barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),)) ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),))
ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(lidx+2, UOp.const(dtypes.bool, True)),)) ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True))),))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))]) uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
ld0 = uops[-1].src[-1] ld0 = uops[-1].src[-1]
@@ -646,7 +647,7 @@ class TestUOpGraph(unittest.TestCase):
idx1 = UOp.const(dtypes.int, 0) idx1 = UOp.const(dtypes.int, 0)
val = UOp.const(dtypes.int, 42) val = UOp.const(dtypes.int, 42)
st0 = glbl.index(UOp.invalid()).store(val) st0 = glbl.index(UOp.invalid()).store(val)
st1 = glbl.index(idx0, UOp.const(dtypes.bool, True)).store(val) st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True))).store(val)
uops = to_uops_list([st0, st1]) uops = to_uops_list([st0, st1])
# only the second store happens # only the second store happens
self.assertEqual(len(uops), 5) self.assertEqual(len(uops), 5)

View File

@@ -277,7 +277,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0') gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
gate = gidx0<UOp.const(dtypes.int, 1) gate = gidx0<UOp.const(dtypes.int, 1)
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, gidx0 * UOp.const(dtypes.int, 2), gate)) idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, (gidx0 * UOp.const(dtypes.int, 2)).valid(gate)))
val = UOp.const(dtypes.float, 42.0) val = UOp.const(dtypes.float, 42.0)
store = UOp(Ops.STORE, dtypes.void, (idx, val)) store = UOp(Ops.STORE, dtypes.void, (idx, val))
uops = to_uops_list([store]) uops = to_uops_list([store])
@@ -294,7 +294,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1) gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0') gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
idx = gidx0 * UOp.const(dtypes.int, 2) idx = gidx0 * UOp.const(dtypes.int, 2)
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx, gidx0<UOp.const(dtypes.int, 1))) idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gidx0<UOp.const(dtypes.int, 1))))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx)) idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
val = UOp.const(dtypes.float, 42.0) val = UOp.const(dtypes.float, 42.0)
stores = [UOp.store(idx0, val), UOp.store(idx1, val)] stores = [UOp.store(idx0, val), UOp.store(idx1, val)]

View File

@@ -1,5 +1,5 @@
import unittest, functools import unittest, functools
from tinygrad import Tensor from tinygrad import Tensor, Context
import numpy as np import numpy as np
def orthogonality_helper(A:Tensor, tolerance=1e-5): def orthogonality_helper(A:Tensor, tolerance=1e-5):
@@ -27,15 +27,16 @@ class TestLinAlg(unittest.TestCase):
reconstruction_helper([U,s_diag,V],a) reconstruction_helper([U,s_diag,V],a)
def _test_svd_nonfull(self, size): def _test_svd_nonfull(self, size):
a = Tensor.randn(size).realize() with Context(IGNORE_OOB=1): # sometimes this is slow in CI
U,S,V = a.svd(full_matrices=False) a = Tensor.randn(size).realize()
b_shape,m,n = size[0:-2],size[-2],size[-1] U,S,V = a.svd(full_matrices=False)
k = min(m,n) b_shape,m,n = size[0:-2],size[-2],size[-1]
s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)).expand(b_shape + (k,k))) k = min(m,n)
#reduced U,V is only orthogonal along smaller dim s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)).expand(b_shape + (k,k)))
if (m < n): orthogonality_helper(U),orthogonality_helper(V) #reduced U,V is only orthogonal along smaller dim
else: orthogonality_helper(U.transpose(-2,-1)),orthogonality_helper(V.transpose(-2,-1)) if (m < n): orthogonality_helper(U),orthogonality_helper(V)
reconstruction_helper([U,s_diag,V],a) else: orthogonality_helper(U.transpose(-2,-1)),orthogonality_helper(V.transpose(-2,-1))
reconstruction_helper([U,s_diag,V],a)
# faster for parallel pytest # faster for parallel pytest
def test_svd_nonfull_2_2(self): self._test_svd_nonfull((2,2)) def test_svd_nonfull_2_2(self): self._test_svd_nonfull((2,2))
@@ -75,4 +76,4 @@ class TestLinAlg(unittest.TestCase):
orthogonality_helper(b if size[-1] > size[-2] else b.transpose(-2, -1), tolerance=1e-3) orthogonality_helper(b if size[-1] > size[-2] else b.transpose(-2, -1), tolerance=1e-3)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -39,6 +39,7 @@ shared_spec = PatternMatcher([
(UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x: (UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \ rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)), all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
]) ])
# ***** UOp spec in the Tensor graph ***** # ***** UOp spec in the Tensor graph *****
@@ -105,9 +106,9 @@ tensor_spec = PatternMatcher([
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True), (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
])+shared_spec ])+shared_spec
# ***** UOp spec in linearized programs ***** # ***** UOp spec in codegen shared between kernel and program *****
program_spec = PatternMatcher([ shared_codegen_spec = PatternMatcher([
# DEFINEs # DEFINEs
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL), (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL), (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL),
@@ -117,42 +118,59 @@ program_spec = PatternMatcher([
(UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True), (UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True),
(UPat(Ops.GROUP, dtypes.void), lambda: True), (UPat(Ops.GROUP, dtypes.void), lambda: True),
# INDEX is used in new style load/store
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True),
# LOAD (idx, alt_value) / STORE(if gated) / LOAD(idx) / STORE(idx, val)
(UPat().index(UPat(), UPat(dtype=dtypes.bool, name="gate"), name="idx").or_casted().load(UPat()), validate_index),
(UPat().index(UPat(), UPat(dtype=dtypes.bool, name="gate"), name="idx").or_casted().store(UPat()), validate_index),
(UPat().index(UPat(), name="idx").or_casted().load(), validate_index),
(UPat().index(UPat(), name="idx").or_casted().store(UPat()), validate_index),
# RANGE/SPECIAL define loops, END closes them # RANGE/SPECIAL define loops, END closes them
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True), (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
# make sure all index dtypes have been lowered
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
(UPat(Ops.CONST, arg=Invalid), lambda: False),
(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.src)),
# WMMA has a <a, b, acc> # WMMA has a <a, b, acc>
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8), (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
# if has a <gate, index_for_dedup> # UNROLL/CONTRACT is used here for WMMA
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True), (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True), (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# VECTORIZE/GEP # VECTORIZE/GEP
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)), (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()), (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
# BARRIER # LOAD(idx) / STORE(idx, val) / LOAD with alt value only exists in program_spec
(UPat(Ops.BARRIER, dtypes.void, src=(UPat(),)), lambda: True), (UPat().index(UPat()).or_casted().load(), lambda: True),
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
# all CUSTOM + PRECAST # all CUSTOM + PRECAST
(UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True), (UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
])+shared_spec
# INDEX
(UPat(GroupOp.Defines, name="buf").or_after().index(UPat.var("idx")), validate_index),
# SPECIAL
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.index, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
# BARRIER
(UPat(Ops.BARRIER, dtypes.void, src=(UPat(),)), lambda: True),
])
# ***** UOp spec in linearized programs *****
program_spec = PatternMatcher([
# INDEX with a gate as third src
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines, name="buf").or_after(), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
# LOAD (idx, alt_value), LOAD can have an alt value, but only if the index has a gate
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(UPat()), lambda: True),
# END closes ranges
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
# make sure all index dtypes have been lowered
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
(UPat(Ops.CONST, arg=Invalid), lambda: False),
(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.arg) and len(x.arg)==x.dtype.vcount>1 and
type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
# if has a <gate, index_for_dedup>
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
])+shared_codegen_spec+shared_spec
# ***** UOp spec in kernel graph ***** # ***** UOp spec in kernel graph *****
@@ -160,14 +178,6 @@ kernel_spec = PatternMatcher([
# index is allowed here # index is allowed here
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True), (UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
# LOAD(idx) / STORE(idx, val) -- NOTE: we do this here to not run validate_index since z3 doesn't support Invalid
(UPat(Ops.INDEX).or_casted().load(), lambda: True),
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
# UNROLL/CONTRACT is used here for WMMA
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# END can end multiple axes here # END can end multiple axes here
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True), (UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True),
@@ -176,10 +186,7 @@ kernel_spec = PatternMatcher([
# reduce must be on ranges # reduce must be on ranges
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])), (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
])+shared_codegen_spec+shared_spec
# intermediate index
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
])+program_spec+shared_spec
# *** this spec should match all UOps ever created *** # *** this spec should match all UOps ever created ***

View File

@@ -1,6 +1,6 @@
from typing import Callable from typing import Callable
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, python_alu, graph_rewrite from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, python_alu, graph_rewrite
from tinygrad.dtype import ImageDType, dtypes from tinygrad.dtype import ImageDType, dtypes, Invalid
from tinygrad.helpers import IGNORE_OOB, Context, cpu_profile from tinygrad.helpers import IGNORE_OOB, Context, cpu_profile
try: try:
@@ -25,15 +25,19 @@ try:
# ctx is (solver, load_number_dict) # ctx is (solver, load_number_dict)
# each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different # each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different
# contexts can have the same hash but error on comparison # contexts can have the same hash but error on comparison
def add_valid(ctx, cond, x):
ctx[0].add(cond.arg[1])
return x
z3_renderer = PatternMatcher([ z3_renderer = PatternMatcher([
(UPat(Ops.NOOP, name="cond").where(UPat(Ops.NOOP, name="x"), UPat(Ops.CONST, arg=Invalid)), add_valid),
(UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))), (UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))), (UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))),
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))), (UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))),
# loaded bools become a z3 int with min max of 0-1 # loaded bools become a z3 int with min max of 0-1
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x,ctx: (UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))).cast(x.dtype)), UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))).cast(x.dtype)),
(UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"), (UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"), lambda x,ctx:
lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))), UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))) if x.arg is not Invalid else None),
# z3 can cast from bool to int automatically # z3 can cast from bool to int automatically
(UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]), (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))), (UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))),
@@ -55,24 +59,26 @@ try:
z3_imported = True z3_imported = True
except (ImportError, AttributeError): z3_imported = False except (ImportError, AttributeError): z3_imported = False
def validate_index(idx:UOp, gate:UOp|None=None): def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
if idx.op is Ops.CONST and idx.arg is Invalid: return True
if gate is None: gate = UOp.const(dtypes.bool, True) if gate is None: gate = UOp.const(dtypes.bool, True)
# TODO: check for overflow # TODO: check for overflow
if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True if IGNORE_OOB or isinstance(buf.dtype, ImageDType) or (sz := buf.ptrdtype.size) == -1: return True
# We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask # We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True if 0<=idx.vmin and idx.vmax<sz: return True
# WEBGPU has a BITCAST in the index. TODO: fix # WEBGPU has a BITCAST in the index. TODO: fix
if any(x.op is Ops.BITCAST for x in idx.toposort()): return True if any(x.op is Ops.BITCAST for x in idx.toposort()): return True
if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"") if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"")
solver = z3.Solver(ctx=z3.Context()) solver = z3.Solver(ctx=z3.Context())
z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], gate) z3_idx, z3_mask = uops_to_z3(solver, idx, gate)
solver.add(z3_mask) solver.add(z3_mask)
with cpu_profile("validate index with z3", "TINY"): with cpu_profile("validate index with z3", "TINY"):
if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat: match solver.check((z3_idx<0)|(sz<=z3_idx)):
print(f"idx={idx.src[1].render(simplify=False)}") case z3.unsat: return True
print(f"gate={gate.render(simplify=False)}") case z3.sat: print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}") case z3.unknown: print(f"# UNKNOWN RESULT FROM Z3: {solver.reason_unknown()}\nconstraints = {solver}")
return False print(f"idx={idx.render(simplify=False)}")
return True print(f"mask={gate.render(simplify=False)}")
return False