mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
shared_codegen_spec and fix index spec (#12967)
* split shared_codegen_spec and fix index * add VCONST to program_spec and move index to shared_codegen_spec * working ignore_oob=0 * cleanup * fix spec * undo that * move barrier and special earlier * fix more spec issues * more updates * remove special from program_spec * cleanup and fixes * move more to shared * special is not in shared_spec * some comments * dont do bounds check there
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import unittest, itertools, math
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad import Tensor, Device, dtypes, Context
|
||||
from tinygrad.dtype import DType, ConstType
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.codegen import full_rewrite_to_sink
|
||||
@@ -126,7 +126,8 @@ class TestBitcastConstFolding(unittest.TestCase):
|
||||
t({dtypes.int64: 4598983288165178391, dtypes.uint64: 4598983288165178391, dtypes.float64: 0.29485681936461233})
|
||||
|
||||
def test_vec_bitcast(self):
|
||||
r = full_rewrite_to_sink(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
|
||||
with Context(SPEC=0):
|
||||
r = full_rewrite_to_sink(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
|
||||
self.assertEqual(r.op, Ops.VECTORIZE)
|
||||
self.assertEqual(r.dtype, dtypes.uint32.vec(3))
|
||||
self.assertEqual(tuple(x.arg for x in r.src), (2**32-1, 2**31, 75))
|
||||
|
||||
@@ -46,7 +46,7 @@ class TestRendererFailures(unittest.TestCase):
|
||||
def test_gated_store_with_alu(self):
|
||||
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, gate_alu), UOp.const(dtypes.int, 1)))
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1)))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
|
||||
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
|
||||
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
|
||||
@@ -57,7 +57,7 @@ class TestRendererFailures(unittest.TestCase):
|
||||
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
|
||||
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0+lidx1*4, gate_alu_0&gate_alu_1), UOp.const(dtypes.int, 1)))
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1)))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
|
||||
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
|
||||
ret = _test_uop_result([], uops, local_size=[4, 2, 1])[0]
|
||||
|
||||
@@ -307,9 +307,10 @@ class TestUOpGraph(unittest.TestCase):
|
||||
for vec_size in [2, 4, 8]:
|
||||
consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
|
||||
vec = UOp(Ops.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
|
||||
uops = to_uops_list([UOp(Ops.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
|
||||
for uop, const in zip(uops, consts):
|
||||
self.assertEqual(uop, const)
|
||||
with Context(SPEC=0):
|
||||
uops = to_uops_list([UOp(Ops.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
|
||||
for uop, const in zip(uops, consts):
|
||||
self.assertEqual(uop, const)
|
||||
|
||||
@unittest.skip("no longer testable standalone")
|
||||
def test_wmma_vectorize_fold(self):
|
||||
@@ -505,10 +506,10 @@ class TestUOpGraph(unittest.TestCase):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0)
|
||||
v = Variable("v", 0, 20)
|
||||
st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v, v<16), UOp.const(dtypes.int, 0)))
|
||||
st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v.valid(v<16)), UOp.const(dtypes.int, 0)))
|
||||
to_uops_list([st0])
|
||||
|
||||
st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v), v, v<20))
|
||||
st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v.valid(v<20)), v))
|
||||
with self.assertRaises(RuntimeError): to_uops_list([st1])
|
||||
|
||||
@unittest.skip("if not allowed in graph")
|
||||
@@ -541,7 +542,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
ridx = UOp.range(20, 0)
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16))),))
|
||||
to_uops_list([ld0])
|
||||
glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
|
||||
ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),))
|
||||
@@ -552,7 +553,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
|
||||
ridx = UOp.range(20, 0)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx, ridx.cast(dtypes.bool).logical_not()),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not())),))
|
||||
to_uops_list([ld0])
|
||||
|
||||
@unittest.skip("Bool load is not supported yet")
|
||||
@@ -574,23 +575,23 @@ class TestUOpGraph(unittest.TestCase):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5<gidx0)&(gidx0<16)),))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<16),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16))),))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16)),))
|
||||
to_uops_list([ld0, ld1])
|
||||
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<17),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17)),))
|
||||
with self.assertRaises(RuntimeError): to_uops_list([ld0])
|
||||
|
||||
def test_in_out_of_bounds_access_symbolic_mask(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
i = Variable("i", 1, 80)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<10),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10)),))
|
||||
to_uops_list([ld0])
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<15),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15)),))
|
||||
to_uops_list([ld0])
|
||||
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<20),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20)),))
|
||||
with self.assertRaises(RuntimeError): to_uops_list([ld0])
|
||||
|
||||
def test_in_out_of_bounds_access_index_load(self):
|
||||
@@ -598,11 +599,11 @@ class TestUOpGraph(unittest.TestCase):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
|
||||
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),)).cast(dtypes.index)
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8)),)).cast(dtypes.index)
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32))),))
|
||||
to_uops_list([ld1])
|
||||
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<64)),))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64))),))
|
||||
with self.assertRaises(RuntimeError): to_uops_list([ld1])
|
||||
|
||||
def test_bounds_with_loaded_bool(self):
|
||||
@@ -620,7 +621,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx, UOp.const(dtypes.bool, True)),))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx.valid(UOp.const(dtypes.bool, True))),))
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
|
||||
ld0 = uops[-1].src[-1]
|
||||
# the gate and invalid value are deleted from ld1
|
||||
@@ -633,7 +634,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
|
||||
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(lidx+2, UOp.const(dtypes.bool, True)),))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True))),))
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
|
||||
|
||||
ld0 = uops[-1].src[-1]
|
||||
@@ -646,7 +647,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
idx1 = UOp.const(dtypes.int, 0)
|
||||
val = UOp.const(dtypes.int, 42)
|
||||
st0 = glbl.index(UOp.invalid()).store(val)
|
||||
st1 = glbl.index(idx0, UOp.const(dtypes.bool, True)).store(val)
|
||||
st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True))).store(val)
|
||||
uops = to_uops_list([st0, st1])
|
||||
# only the second store happens
|
||||
self.assertEqual(len(uops), 5)
|
||||
|
||||
@@ -277,7 +277,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, gidx0 * UOp.const(dtypes.int, 2), gate))
|
||||
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, (gidx0 * UOp.const(dtypes.int, 2)).valid(gate)))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
store = UOp(Ops.STORE, dtypes.void, (idx, val))
|
||||
uops = to_uops_list([store])
|
||||
@@ -294,7 +294,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
||||
idx = gidx0 * UOp.const(dtypes.int, 2)
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx, gidx0<UOp.const(dtypes.int, 1)))
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gidx0<UOp.const(dtypes.int, 1))))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import unittest, functools
|
||||
from tinygrad import Tensor
|
||||
from tinygrad import Tensor, Context
|
||||
import numpy as np
|
||||
|
||||
def orthogonality_helper(A:Tensor, tolerance=1e-5):
|
||||
@@ -27,15 +27,16 @@ class TestLinAlg(unittest.TestCase):
|
||||
reconstruction_helper([U,s_diag,V],a)
|
||||
|
||||
def _test_svd_nonfull(self, size):
|
||||
a = Tensor.randn(size).realize()
|
||||
U,S,V = a.svd(full_matrices=False)
|
||||
b_shape,m,n = size[0:-2],size[-2],size[-1]
|
||||
k = min(m,n)
|
||||
s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)).expand(b_shape + (k,k)))
|
||||
#reduced U,V is only orthogonal along smaller dim
|
||||
if (m < n): orthogonality_helper(U),orthogonality_helper(V)
|
||||
else: orthogonality_helper(U.transpose(-2,-1)),orthogonality_helper(V.transpose(-2,-1))
|
||||
reconstruction_helper([U,s_diag,V],a)
|
||||
with Context(IGNORE_OOB=1): # sometimes this is slow in CI
|
||||
a = Tensor.randn(size).realize()
|
||||
U,S,V = a.svd(full_matrices=False)
|
||||
b_shape,m,n = size[0:-2],size[-2],size[-1]
|
||||
k = min(m,n)
|
||||
s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)).expand(b_shape + (k,k)))
|
||||
#reduced U,V is only orthogonal along smaller dim
|
||||
if (m < n): orthogonality_helper(U),orthogonality_helper(V)
|
||||
else: orthogonality_helper(U.transpose(-2,-1)),orthogonality_helper(V.transpose(-2,-1))
|
||||
reconstruction_helper([U,s_diag,V],a)
|
||||
|
||||
# faster for parallel pytest
|
||||
def test_svd_nonfull_2_2(self): self._test_svd_nonfull((2,2))
|
||||
@@ -75,4 +76,4 @@ class TestLinAlg(unittest.TestCase):
|
||||
orthogonality_helper(b if size[-1] > size[-2] else b.transpose(-2, -1), tolerance=1e-3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@@ -39,6 +39,7 @@ shared_spec = PatternMatcher([
|
||||
(UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
|
||||
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
|
||||
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
|
||||
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
|
||||
])
|
||||
|
||||
# ***** UOp spec in the Tensor graph *****
|
||||
@@ -105,9 +106,9 @@ tensor_spec = PatternMatcher([
|
||||
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
|
||||
])+shared_spec
|
||||
|
||||
# ***** UOp spec in linearized programs *****
|
||||
# ***** UOp spec in codegen shared between kernel and program *****
|
||||
|
||||
program_spec = PatternMatcher([
|
||||
shared_codegen_spec = PatternMatcher([
|
||||
# DEFINEs
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL),
|
||||
@@ -117,42 +118,59 @@ program_spec = PatternMatcher([
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True),
|
||||
(UPat(Ops.GROUP, dtypes.void), lambda: True),
|
||||
|
||||
# INDEX is used in new style load/store
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True),
|
||||
|
||||
# LOAD (idx, alt_value) / STORE(if gated) / LOAD(idx) / STORE(idx, val)
|
||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool, name="gate"), name="idx").or_casted().load(UPat()), validate_index),
|
||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool, name="gate"), name="idx").or_casted().store(UPat()), validate_index),
|
||||
(UPat().index(UPat(), name="idx").or_casted().load(), validate_index),
|
||||
(UPat().index(UPat(), name="idx").or_casted().store(UPat()), validate_index),
|
||||
|
||||
# RANGE/SPECIAL define loops, END closes them
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
|
||||
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
|
||||
|
||||
# make sure all index dtypes have been lowered
|
||||
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
|
||||
(UPat(Ops.CONST, arg=Invalid), lambda: False),
|
||||
(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.src)),
|
||||
|
||||
# WMMA has a <a, b, acc>
|
||||
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
|
||||
|
||||
# if has a <gate, index_for_dedup>
|
||||
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),
|
||||
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
|
||||
# UNROLL/CONTRACT is used here for WMMA
|
||||
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
||||
|
||||
# VECTORIZE/GEP
|
||||
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
||||
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||
|
||||
# BARRIER
|
||||
(UPat(Ops.BARRIER, dtypes.void, src=(UPat(),)), lambda: True),
|
||||
# LOAD(idx) / STORE(idx, val) / LOAD with alt value only exists in program_spec
|
||||
(UPat().index(UPat()).or_casted().load(), lambda: True),
|
||||
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
|
||||
|
||||
# all CUSTOM + PRECAST
|
||||
(UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
|
||||
])+shared_spec
|
||||
|
||||
# INDEX
|
||||
(UPat(GroupOp.Defines, name="buf").or_after().index(UPat.var("idx")), validate_index),
|
||||
|
||||
# SPECIAL
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.index, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
|
||||
|
||||
# BARRIER
|
||||
(UPat(Ops.BARRIER, dtypes.void, src=(UPat(),)), lambda: True),
|
||||
])
|
||||
|
||||
# ***** UOp spec in linearized programs *****
|
||||
|
||||
program_spec = PatternMatcher([
|
||||
# INDEX with a gate as third src
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines, name="buf").or_after(), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
|
||||
|
||||
# LOAD (idx, alt_value), LOAD can have an alt value, but only if the index has a gate
|
||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(UPat()), lambda: True),
|
||||
|
||||
# END closes ranges
|
||||
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
|
||||
|
||||
# make sure all index dtypes have been lowered
|
||||
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
|
||||
(UPat(Ops.CONST, arg=Invalid), lambda: False),
|
||||
(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.arg) and len(x.arg)==x.dtype.vcount>1 and
|
||||
type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
|
||||
|
||||
# if has a <gate, index_for_dedup>
|
||||
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),
|
||||
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
|
||||
])+shared_codegen_spec+shared_spec
|
||||
|
||||
# ***** UOp spec in kernel graph *****
|
||||
|
||||
@@ -160,14 +178,6 @@ kernel_spec = PatternMatcher([
|
||||
# index is allowed here
|
||||
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
|
||||
|
||||
# LOAD(idx) / STORE(idx, val) -- NOTE: we do this here to not run validate_index since z3 doesn't support Invalid
|
||||
(UPat(Ops.INDEX).or_casted().load(), lambda: True),
|
||||
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
|
||||
|
||||
# UNROLL/CONTRACT is used here for WMMA
|
||||
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
||||
|
||||
# END can end multiple axes here
|
||||
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True),
|
||||
|
||||
@@ -176,10 +186,7 @@ kernel_spec = PatternMatcher([
|
||||
|
||||
# reduce must be on ranges
|
||||
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
|
||||
|
||||
# intermediate index
|
||||
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
|
||||
])+program_spec+shared_spec
|
||||
])+shared_codegen_spec+shared_spec
|
||||
|
||||
# *** this spec should match all UOps ever created ***
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Callable
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, python_alu, graph_rewrite
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.dtype import ImageDType, dtypes, Invalid
|
||||
from tinygrad.helpers import IGNORE_OOB, Context, cpu_profile
|
||||
|
||||
try:
|
||||
@@ -25,15 +25,19 @@ try:
|
||||
# ctx is (solver, load_number_dict)
|
||||
# each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different
|
||||
# contexts can have the same hash but error on comparison
|
||||
def add_valid(ctx, cond, x):
|
||||
ctx[0].add(cond.arg[1])
|
||||
return x
|
||||
z3_renderer = PatternMatcher([
|
||||
(UPat(Ops.NOOP, name="cond").where(UPat(Ops.NOOP, name="x"), UPat(Ops.CONST, arg=Invalid)), add_valid),
|
||||
(UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))),
|
||||
# loaded bools become a z3 int with min max of 0-1
|
||||
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x,ctx:
|
||||
UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))).cast(x.dtype)),
|
||||
(UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"),
|
||||
lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))),
|
||||
(UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"), lambda x,ctx:
|
||||
UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))) if x.arg is not Invalid else None),
|
||||
# z3 can cast from bool to int automatically
|
||||
(UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
|
||||
(UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))),
|
||||
@@ -55,24 +59,26 @@ try:
|
||||
z3_imported = True
|
||||
except (ImportError, AttributeError): z3_imported = False
|
||||
|
||||
def validate_index(idx:UOp, gate:UOp|None=None):
|
||||
def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
|
||||
if idx.op is Ops.CONST and idx.arg is Invalid: return True
|
||||
if gate is None: gate = UOp.const(dtypes.bool, True)
|
||||
# TODO: check for overflow
|
||||
if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True
|
||||
if IGNORE_OOB or isinstance(buf.dtype, ImageDType) or (sz := buf.ptrdtype.size) == -1: return True
|
||||
# We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
|
||||
if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
|
||||
if 0<=idx.vmin and idx.vmax<sz: return True
|
||||
|
||||
# WEBGPU has a BITCAST in the index. TODO: fix
|
||||
if any(x.op is Ops.BITCAST for x in idx.toposort()): return True
|
||||
|
||||
if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"")
|
||||
solver = z3.Solver(ctx=z3.Context())
|
||||
z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], gate)
|
||||
z3_idx, z3_mask = uops_to_z3(solver, idx, gate)
|
||||
solver.add(z3_mask)
|
||||
with cpu_profile("validate index with z3", "TINY"):
|
||||
if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
|
||||
print(f"idx={idx.src[1].render(simplify=False)}")
|
||||
print(f"gate={gate.render(simplify=False)}")
|
||||
print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
|
||||
return False
|
||||
return True
|
||||
match solver.check((z3_idx<0)|(sz<=z3_idx)):
|
||||
case z3.unsat: return True
|
||||
case z3.sat: print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
|
||||
case z3.unknown: print(f"# UNKNOWN RESULT FROM Z3: {solver.reason_unknown()}\nconstraints = {solver}")
|
||||
print(f"idx={idx.render(simplify=False)}")
|
||||
print(f"mask={gate.render(simplify=False)}")
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user