clean up the spec (#12868)

* tighten up the spec

* move validate into a different file

* that moved to validate

* after(barr)
George Hotz, 2025-10-22 19:50:42 +08:00, committed by GitHub
parent 726988fa4b, commit 7762b3558b
10 changed files with 166 additions and 165 deletions

View File

@@ -4,7 +4,7 @@ from tinygrad.helpers import Profiling, Timing, getenv
 from tinygrad.uop.ops import Ops
 from tinygrad.codegen import full_rewrite_to_sink
 from tinygrad.codegen.late.control_flow import linearize
-from tinygrad.uop.spec import type_verify
+from tinygrad.uop.spec import type_verify, program_spec

 if __name__ == "__main__":
   mdl = ResNet50()
@@ -41,5 +41,5 @@ if __name__ == "__main__":
   for u in rewritten_uops:
     uops_line.append(linearize(u))
   with Timing("***** model verify in "):
-    for u in uops_line: type_verify(u)
+    for u in uops_line: type_verify(u, program_spec)
   print(sum(len(u) for u in uops_line))
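Side note on this call-site change, which repeats across the commit: type_verify used to fall back to a default spec when called with one argument; after this commit the caller always names the spec. A minimal sketch of the new contract, where uops is a hypothetical linearized list[UOp]:

from tinygrad.uop.spec import type_verify, program_spec

# before this commit: type_verify(uops) implicitly used the rendered-program spec
# after this commit: the spec is an explicit, required second argument
type_verify(uops, program_spec)  # uops: hypothetical list[UOp] from linearize()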

View File

@@ -1,7 +1,7 @@
 import random
 import z3
 from tinygrad import dtypes
-from tinygrad.uop.spec import uops_to_z3, z3_cdiv
+from tinygrad.uop.validate import uops_to_z3, z3_cdiv
 from tinygrad.uop.ops import UOp
 from tinygrad.uop.decompositions import fast_idiv
 random.seed(42)

View File

@@ -2,7 +2,7 @@ import random, operator
 import z3
 from tinygrad import Variable, dtypes
 from tinygrad.uop.ops import UOp
-from tinygrad.uop.spec import uops_to_z3
+from tinygrad.uop.validate import uops_to_z3
 from tinygrad.helpers import DEBUG, Context
 seed = random.randint(0, 100)

View File

@@ -639,13 +639,13 @@ class TestUOpGraph(unittest.TestCase):
     lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
     st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
     barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
-    ld0 = UOp(Ops.LOAD, dtypes.int, (smem.index(UOp.invalid()), barrier))
-    ld1 = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx+2, UOp.const(dtypes.bool, True)), barrier))
+    ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),))
+    ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(lidx+2, UOp.const(dtypes.bool, True)),))
     uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
     ld0 = uops[-1].src[-1]
     # the gate and invalid value are deleted from ld1
-    self.assertEqual(ld0.src[0], smem.index(lidx+2))
+    self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2))

   def test_fold_gated_store(self):
     glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
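All the barrier hunks in this commit (here and in test_uops.py below) make the same move: the BARRIER stops riding along as an extra LOAD source and instead orders the load through the buffer itself. A condensed sketch of the before/after shapes, reusing smem, lidx, and st from the test above:

barrier = UOp(Ops.BARRIER, dtypes.void, (st,))
# pre-commit: the barrier is a second source of the LOAD
ld_old = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx), barrier))
# post-commit: the ordering dependency hangs off the buffer via after(barrier)
ld_new = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(lidx),))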

View File

@@ -6,7 +6,7 @@ from tinygrad.helpers import CI, DEBUG, getenv, Timing
 from tinygrad.dtype import dtypes, DType, AddrSpace
 from tinygrad.device import Buffer, Device
 from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu # noqa F401
-from tinygrad.uop.spec import spec
+from tinygrad.uop.spec import shared_spec
 from tinygrad.renderer import ProgramSpec
 from tinygrad.engine.realize import CompiledRunner, get_program
 from tinygrad.codegen import full_rewrite
@@ -332,7 +332,7 @@ class TestLocalAccess(unittest.TestCase):
     smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
     st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.float32, (), 42.0)))
     barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
-    sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
+    sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
     self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)
   # NOTE: webgpu specific, since only webgpu performs bitpacking
@@ -342,7 +342,7 @@ class TestLocalAccess(unittest.TestCase):
     smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
     st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
     barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
-    sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
+    sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
     self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42)
   # NOTE: webgpu specific, since only webgpu performs bitpacking
@@ -513,7 +513,7 @@ class TestUOpStr(unittest.TestCase):
 class TestUPatHelpers(unittest.TestCase):
   def test_location(self):
     self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "symbolic.py")
-    self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py")
+    self.assertEqual(shared_spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py")
     test_upat = UPat(Ops.CONST, dtypes.bool)
     self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
     test_upat_named = test_upat.named("test_name")

View File

@@ -7,7 +7,7 @@ from tinygrad.codegen import full_rewrite
 from tinygrad.helpers import Context
 from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer
 from tinygrad.uop.symbolic import sym, commutative
-from tinygrad.uop.spec import uops_to_z3
+from tinygrad.uop.validate import uops_to_z3

 def check_uop_against_string(self, v:UOp, s:str):
   sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL)}

View File

@@ -1,6 +1,6 @@
 from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
 from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
-from tinygrad.uop.spec import type_verify
+from tinygrad.uop.spec import type_verify, program_spec
 from tinygrad.renderer import Renderer

 # import all pattern matchers here
@@ -98,5 +98,5 @@ def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
   """
   lst = linearize(full_rewrite_to_sink(sink, ren, optimize=sink.tag is None))
-  if __debug__: type_verify(lst)
+  if __debug__: type_verify(lst, program_spec)
   return lst

View File

@@ -11,7 +11,7 @@ from tinygrad.helpers import suppress_finalizing
 from tinygrad.gradient import compute_gradient
 from tinygrad.uop.mathtraits import MathTrait
 from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, srender
-from tinygrad.uop.spec import tensor_uop_spec, type_verify
+from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.memory import memory_planner
@@ -229,7 +229,7 @@ class Tensor(MathTrait):
     big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
     # verify Tensors match the spec
-    if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec)
+    if __debug__: type_verify(list(big_sink.toposort()), tensor_spec)
     if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
       _apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")

View File

@@ -1,60 +1,45 @@
-from typing import cast, Callable
-from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite, AxisType
+from typing import cast
+from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType
 from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid
-from tinygrad.helpers import all_same, prod, DEBUG, IGNORE_OOB, Context, cpu_profile
-try:
-  import z3
-  # older versions of z3 dont have some operators like & overloaded
-  if z3.get_version() < (4, 12, 4, 0): raise ImportError
-  # IDIV is truncated division but z3 does euclidian division (floor if b>0 ceil otherwise); mod by power of two sometimes uses Ops.AND
-  def z3_cdiv(a, b):return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
-  def z3_xor(a,b):
-    if isinstance(a, z3.BoolRef): return a^b
-    assert a==-1 or b==-1, "xor can only be used in indexing if one of the aruments is -1"
-    return -a-1 if b==-1 else -b-1
-  z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()),
-    Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor,
-    Ops.MAX: lambda a,b: z3.If(a<b, b, a), Ops.TRUNC: lambda a: a if a.is_int() else z3.ToReal(z3.If(a >= 0, z3.ToInt(a), -z3.ToInt(-a)))}
-  def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef:
-    s = z3.Int(name, ctx=solver.ctx)
-    solver.add(vmin <= s, s <= vmax)
-    return s
-  # ctx is (solver, load_number_dict)
-  # each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different
-  # contexts can have the same hash but error on comparison
-  z3_renderer = PatternMatcher([
-    (UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))),
-    (UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))),
-    (UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))),
-    # loaded bools become a z3 int with min max of 0-1
-    (UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x,ctx:
-      UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))).cast(x.dtype)),
-    (UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"),
-      lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))),
-    # z3 can cast from bool to int automatically
-    (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
-    (UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))),
-    # if the source of the cast is not a noop it means that it is a float and so we create a new variable
-    (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx:
-      UOp(Ops.NOOP, arg=(ctx[0], create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))),
-    (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x,ctx:
-      UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
-    (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3_alu[x.op](*(s.arg[1] for s in x.src))))),
-    # A comparison between floats introduces a new bool variable
-    (UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx:
-      UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"float_cmp{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
-  ])
-  def uops_to_z3(solver, *uops: UOp) -> 'list[z3.ExprRef]':
-    with Context(TRACK_MATCH_STATS=0, SPEC=0): # cant pickle z3 objects, and these UOps don't follow spec
-      return [s.arg[1] for s in graph_rewrite(uops[0].sink(*uops[1:]), z3_renderer, ctx=(solver, {})).src]
-  z3_imported = True
-except (ImportError, AttributeError): z3_imported = False
-buffer_spec = PatternMatcher([
+from tinygrad.helpers import DEBUG, Context
+from tinygrad.uop.validate import validate_index
+
+# four specs:
+# shared_spec -- usable anywhere
+# tensor_spec -- usable in tensor graph
+# program_spec -- usable in linearized program
+# full_spec -- all uops ever created
+
+# *** these uops work anywhere ***
+
+shared_spec = PatternMatcher([
+  (UPat(Ops.SINK, dtypes.void), lambda: True), # NOTE: for testing, we let sinks be anything
+  # CONST/DEFINE_VAR are everywhere
+  (UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
+  (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
+  # ALUs: most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
+  (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
+  (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
+  # and SHL/SHR, the shift distance can be an int
+  (UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
+  (UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
+  (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
+  # CAST
+  (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
+  # RANGE can be in the big graph now
+  (UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
+    rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
+      all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
+])
+
+# ***** UOp spec in the Tensor graph *****
+
+tensor_spec = PatternMatcher([
+  # buffer spec
   (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
   (UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d:
     isinstance(d.arg, str) or (isinstance(d.arg, tuple) and all(isinstance(s, str) for s in d.arg))),
@@ -63,9 +48,7 @@ buffer_spec = PatternMatcher([
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"), (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)), lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.MSTACK, src=UPat(Ops.BUFFER)),)), lambda: True), (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.MSTACK, src=UPat(Ops.BUFFER)),)), lambda: True),
])
assign_spec = PatternMatcher([
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER # KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True), (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
@@ -77,11 +60,7 @@ assign_spec = PatternMatcher([
   # MSTACK combines buffers into multi
   (UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
-])
-# *** this is the spec of a Tensor in UOp ***
-tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
   (UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
   (UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
   (UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
@@ -109,127 +88,69 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
(UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)), (UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
(UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)), (UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)),
# REDUCE_AXIS is the reduce in the tensor graph
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
# REDUCE with an outerworld range # REDUCE with an outerworld range
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])), (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
# AFTER if things were kernelized # AFTER if things were kernelized
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True) (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
]) ])+shared_spec
# ***** uop type spec ***** # ***** UOp spec in linearized programs *****
def validate_index(idx:UOp, gate:UOp|None=None): program_spec = PatternMatcher([
if gate is None: gate = UOp.const(dtypes.bool, True) # DEFINEs
# TODO: check for overflow
if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True
# We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
mask = idx.src[2]&gate if len(idx.src)==3 else gate
# WEBGPU has a BITCAST in the index. TODO: fix
if any(x.op is Ops.BITCAST for x in idx.toposort()): return True
if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"")
solver = z3.Solver(ctx=z3.Context())
z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], mask)
solver.add(z3_mask)
with cpu_profile("validate index with z3", "TINY"):
if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
print(f"idx={idx.src[1].render(simplify=False)}")
print(f"mask & gate={mask.render(simplify=False)}")
print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
return False
return True
def validate_store(idx:UOp, val:UOp, gate:UOp|None=None):
if gate is None: gate = UOp.const(dtypes.bool, True)
if gate.op is Ops.IF: gate = gate.src[0]
# we need to find the implicit gates, inverse of delete_redundant_gates
for u in val.toposort():
if u.op is Ops.IF: gate &= u.src[0]
return validate_index(idx, gate)
index_pat = UPat(Ops.INDEX, name="idx").or_casted()
# this is the matcher for the final rendered UOps
# matcher functions returns True or False (or None to not match)
spec = PatternMatcher([
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL), (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL), (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL),
(UPat(Ops.DEFINE_REG, src=()), lambda: True), (UPat(Ops.DEFINE_REG, src=()), lambda: True),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
(UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
# allow AFTER on buffers # allow AFTER on buffers
(UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True), (UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True),
# **** new style load/store **** # INDEX is used in new style load/store
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True),
# LOAD (idx, alt_value) / LOAD(idx) / STORE(idx, val)
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, name="idx").or_casted(), UPat((Ops.VECTORIZE, Ops.VCONST, Ops.CONST)))), validate_index),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, name="idx").or_casted(), )), validate_index),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, name="idx").or_casted(), UPat())), validate_index),
# RANGE/SPECIAL define loops, END closes them
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
(UPat(Ops.END, src=(UPat(Ops.RANGE), UPat()), allow_any_len=True, arg=1, dtype=dtypes.void), lambda: True),
# make sure all index dtypes have been lowered # make sure all index dtypes have been lowered
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False), (UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
(UPat(Ops.CONST, arg=Invalid), lambda: False), (UPat(Ops.CONST, arg=Invalid), lambda: False),
(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.src)), (UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.src)),
# INDEX is used in new style load/store
# INDEX takes a <buf, alu, gate?>
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
# LOAD takes a <bufidx, alt?, barrier?>
(UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond")), allow_any_len=True), lambda idx,cond: validate_index(idx,cond.src[0])),
(UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True), validate_index),
# STORE takes a <bufidx, val, ranges...>
(UPat(Ops.STORE, src=(index_pat, UPat(name="val")), allow_any_len=True), validate_store),
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
# and SHL/SHR, the shift distance can be an int
(UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
(UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
(UPat(Ops.END, dtype=dtypes.void), lambda: True),
# WMMA has a <a, b, acc> # WMMA has a <a, b, acc>
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8), (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# if has a <gate, barrier?> # if has a <gate, index_for_dedup>
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),), allow_any_len=True), lambda: True), (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),), allow_any_len=True), lambda: True), (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}), # VECTORIZE/GEP
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)), (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None), (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
# BARRIER
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
(UPat(Ops.BARRIER, dtypes.void), lambda: True), # BARRIERs can also happen at the end of loops (UPat(Ops.BARRIER, dtypes.void), lambda: True), # BARRIERs can also happen at the end of loops
# NOTE: for testing, we let sinks be anything
#(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
(UPat(Ops.SINK, dtypes.void), lambda: True),
(UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True), (UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
]) ])+shared_spec
# *** this is the UOp AST spec ***
ast_spec = PatternMatcher([
# all parent UOps must have the same shape
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
])
# *** this spec should match all UOps ever created *** # *** this spec should match all UOps ever created ***
full_spec = PatternMatcher([ full_spec = PatternMatcher([
# any END
(UPat(Ops.END), lambda: True),
# SENTINEL should never be in the graph # SENTINEL should never be in the graph
(UPat(Ops.SENTINEL), lambda: False), (UPat(Ops.SENTINEL), lambda: False),
@@ -254,6 +175,8 @@ full_spec = PatternMatcher([
   (UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat(GroupOp.Movement))), lambda: True),
   # expander: unroll/contract/gep/ptrcat/cat
+  #(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
+  #(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
   (UPat((Ops.UNROLL, Ops.CONTRACT), src=(UPat(),)), lambda: True),
   # GEP multi is supported here
   (UPat(Ops.GEP, name="gep"), lambda gep: gep.dtype is dtypes.void or gep.dtype.vcount == len(gep.arg)),
@@ -281,12 +204,11 @@ full_spec = PatternMatcher([
   (UPat(Ops.RESHAPE, src=(UPat(Ops.STORE),)), lambda: True),
   # allow any AFTER
   (UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True),
-])+tensor_uop_spec+spec
+])+tensor_spec+program_spec

 # ***** uop helpers *****

-def type_verify(uops:list[UOp], extra_spec:PatternMatcher|None=None):
-  check_spec = (extra_spec+spec) if extra_spec is not None else spec
+def type_verify(uops:list[UOp], check_spec:PatternMatcher):
   for i,u in enumerate(uops):
     with Context(TRACK_MATCH_STATS=0): ret = check_spec.rewrite(u)
     if cast(bool|None, ret) is not True:
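The net layering after this rewrite: shared_spec holds the rules valid in any graph, and PatternMatcher addition folds it into both tensor_spec and program_spec, while full_spec unions everything. A rough usage sketch of the resulting API (the uop lists are hypothetical stand-ins):

from tinygrad.uop.spec import type_verify, tensor_spec, program_spec

type_verify(list(big_sink.toposort()), tensor_spec)  # big-graph Tensor uops, as in tensor.py above (hypothetical big_sink)
type_verify(linearized_uops, program_spec)           # output of linearize(), as in codegen above (hypothetical list)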

tinygrad/uop/validate.py (new file, +79 lines)
View File

@@ -0,0 +1,79 @@
from typing import Callable
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, python_alu, graph_rewrite
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.helpers import IGNORE_OOB, Context, cpu_profile

try:
  import z3

  # older versions of z3 dont have some operators like & overloaded
  if z3.get_version() < (4, 12, 4, 0): raise ImportError

  # IDIV is truncated division but z3 does euclidean division (floor if b>0 ceil otherwise); mod by power of two sometimes uses Ops.AND
  def z3_cdiv(a, b): return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
  def z3_xor(a,b):
    if isinstance(a, z3.BoolRef): return a^b
    assert a==-1 or b==-1, "xor can only be used in indexing if one of the arguments is -1"
    return -a-1 if b==-1 else -b-1
  z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()),
    Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor,
    Ops.MAX: lambda a,b: z3.If(a<b, b, a), Ops.TRUNC: lambda a: a if a.is_int() else z3.ToReal(z3.If(a >= 0, z3.ToInt(a), -z3.ToInt(-a)))}

  def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef:
    s = z3.Int(name, ctx=solver.ctx)
    solver.add(vmin <= s, s <= vmax)
    return s

  # ctx is (solver, load_number_dict)
  # each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different
  # contexts can have the same hash but error on comparison
  z3_renderer = PatternMatcher([
    (UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))),
    (UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))),
    (UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))),
    # loaded bools become a z3 int with min max of 0-1
    (UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x,ctx:
      UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))).cast(x.dtype)),
    (UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"),
      lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))),
    # z3 can cast from bool to int automatically
    (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
    (UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))),
    # if the source of the cast is not a noop it means that it is a float and so we create a new variable
    (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx:
      UOp(Ops.NOOP, arg=(ctx[0], create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))),
    (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x,ctx:
      UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
    (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3_alu[x.op](*(s.arg[1] for s in x.src))))),
    # A comparison between floats introduces a new bool variable
    (UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx:
      UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"float_cmp{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
  ])

  def uops_to_z3(solver, *uops: UOp) -> 'list[z3.ExprRef]':
    with Context(TRACK_MATCH_STATS=0, SPEC=0): # cant pickle z3 objects, and these UOps don't follow spec
      return [s.arg[1] for s in graph_rewrite(uops[0].sink(*uops[1:]), z3_renderer, ctx=(solver, {})).src]

  z3_imported = True
except (ImportError, AttributeError): z3_imported = False

def validate_index(idx:UOp, gate:UOp|None=None):
  if gate is None: gate = UOp.const(dtypes.bool, True)
  # TODO: check for overflow
  if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True
  # We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
  if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
  mask = idx.src[2]&gate if len(idx.src)==3 else gate
  # WEBGPU has a BITCAST in the index. TODO: fix
  if any(x.op is Ops.BITCAST for x in idx.toposort()): return True
  if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"")
  solver = z3.Solver(ctx=z3.Context())
  z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], mask)
  solver.add(z3_mask)
  with cpu_profile("validate index with z3", "TINY"):
    if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
      print(f"idx={idx.src[1].render(simplify=False)}")
      print(f"mask & gate={mask.render(simplify=False)}")
      print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
      return False
  return True
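For orientation, a minimal sketch (not part of the commit) of driving the moved helpers by hand; it assumes z3-solver >= 4.12.4 is installed:

import z3
from tinygrad.uop.ops import UOp
from tinygrad.uop.validate import uops_to_z3

solver = z3.Solver(ctx=z3.Context())
v = UOp.variable("v", 0, 9)            # integer variable bounded to [0, 9]
z3_expr, = uops_to_z3(solver, v*2+1)   # render the uop expression into z3 terms
assert solver.check(z3_expr > 19) == z3.unsat  # v <= 9, so v*2+1 can never exceed 19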