From 7762b3558ba7d99882cafaf05ffae3dd0f630878 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Wed, 22 Oct 2025 19:50:42 +0800
Subject: [PATCH] clean up the spec (#12868)

* tighten up the spec

* move validate into a different file

* that moved to validate

* after(barr)
---
 test/external/external_benchmark_schedule.py |   4 +-
 test/external/fuzz_fast_idiv.py              |   2 +-
 test/external/fuzz_symbolic.py               |   2 +-
 test/test_uop_graph.py                       |   6 +-
 test/test_uops.py                            |   8 +-
 test/unit/test_uop_symbolic.py               |   2 +-
 tinygrad/codegen/__init__.py                 |   4 +-
 tinygrad/tensor.py                           |   4 +-
 tinygrad/uop/spec.py                         | 220 ++++++-------------
 tinygrad/uop/validate.py                     |  79 +++++++
 10 files changed, 166 insertions(+), 165 deletions(-)
 create mode 100644 tinygrad/uop/validate.py

diff --git a/test/external/external_benchmark_schedule.py b/test/external/external_benchmark_schedule.py
index d377969cd5..40f6a2114b 100644
--- a/test/external/external_benchmark_schedule.py
+++ b/test/external/external_benchmark_schedule.py
@@ -4,7 +4,7 @@ from tinygrad.helpers import Profiling, Timing, getenv
 from tinygrad.uop.ops import Ops
 from tinygrad.codegen import full_rewrite_to_sink
 from tinygrad.codegen.late.control_flow import linearize
-from tinygrad.uop.spec import type_verify
+from tinygrad.uop.spec import type_verify, program_spec
 
 if __name__ == "__main__":
   mdl = ResNet50()
@@ -41,5 +41,5 @@ if __name__ == "__main__":
     for u in rewritten_uops: uops_line.append(linearize(u))
 
   with Timing("***** model verify in "):
-    for u in uops_line: type_verify(u)
+    for u in uops_line: type_verify(u, program_spec)
   print(sum(len(u) for u in uops_line))
diff --git a/test/external/fuzz_fast_idiv.py b/test/external/fuzz_fast_idiv.py
index a6e48f1d8a..8d6e556b1a 100644
--- a/test/external/fuzz_fast_idiv.py
+++ b/test/external/fuzz_fast_idiv.py
@@ -1,7 +1,7 @@
 import random
 import z3
 from tinygrad import dtypes
-from tinygrad.uop.spec import uops_to_z3, z3_cdiv
+from tinygrad.uop.validate import uops_to_z3, z3_cdiv
 from tinygrad.uop.ops import UOp
 from tinygrad.uop.decompositions import fast_idiv
 random.seed(42)
diff --git a/test/external/fuzz_symbolic.py b/test/external/fuzz_symbolic.py
index 0e87883d6b..060ce2ee8b 100644
--- a/test/external/fuzz_symbolic.py
+++ b/test/external/fuzz_symbolic.py
@@ -2,7 +2,7 @@ import random, operator
 import z3
 from tinygrad import Variable, dtypes
 from tinygrad.uop.ops import UOp
-from tinygrad.uop.spec import uops_to_z3
+from tinygrad.uop.validate import uops_to_z3
 from tinygrad.helpers import DEBUG, Context
 
 seed = random.randint(0, 100)
diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py
index 0409282391..4fd01cdcfd 100644
--- a/test/test_uop_graph.py
+++ b/test/test_uop_graph.py
@@ -639,13 +639,13 @@ class TestUOpGraph(unittest.TestCase):
     lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
     st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
     barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
-    ld0 = UOp(Ops.LOAD, dtypes.int, (smem.index(UOp.invalid()), barrier))
-    ld1 = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx+2, UOp.const(dtypes.bool, True)), barrier))
+    ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),))
+    ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(lidx+2, UOp.const(dtypes.bool, True)),))
     uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
     ld0 = uops[-1].src[-1]
     # the gate and invalid value are deleted from ld1
-    self.assertEqual(ld0.src[0], smem.index(lidx+2))
+    self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2))
 
   def test_fold_gated_store(self):
     glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
diff --git a/test/test_uops.py b/test/test_uops.py
index 1f3c9b4ad9..e0d2fb6cd3 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -6,7 +6,7 @@ from tinygrad.helpers import CI, DEBUG, getenv, Timing
 from tinygrad.dtype import dtypes, DType, AddrSpace
 from tinygrad.device import Buffer, Device
 from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu # noqa F401
-from tinygrad.uop.spec import spec
+from tinygrad.uop.spec import shared_spec
 from tinygrad.renderer import ProgramSpec
 from tinygrad.engine.realize import CompiledRunner, get_program
 from tinygrad.codegen import full_rewrite
@@ -332,7 +332,7 @@ class TestLocalAccess(unittest.TestCase):
     smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
     st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.float32, (), 42.0)))
     barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
-    sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
+    sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
     self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)
 
   # NOTE: webgpu specific, since only webgpu performs bitpacking
@@ -342,7 +342,7 @@ class TestLocalAccess(unittest.TestCase):
     smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
     st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
     barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
-    sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
+    sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
     self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42)
 
   # NOTE: webgpu specific, since only webgpu performs bitpacking
@@ -513,7 +513,7 @@ class TestUOpStr(unittest.TestCase):
 class TestUPatHelpers(unittest.TestCase):
   def test_location(self):
     self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "symbolic.py")
-    self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py")
+    self.assertEqual(shared_spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py")
     test_upat = UPat(Ops.CONST, dtypes.bool)
     self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
     test_upat_named = test_upat.named("test_name")
diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py
index 3c1805ff3b..cdbbc265f9 100644
--- a/test/unit/test_uop_symbolic.py
+++ b/test/unit/test_uop_symbolic.py
@@ -7,7 +7,7 @@ from tinygrad.codegen import full_rewrite
 from tinygrad.helpers import Context
 from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer
 from tinygrad.uop.symbolic import sym, commutative
-from tinygrad.uop.spec import uops_to_z3
+from tinygrad.uop.validate import uops_to_z3
 
 def check_uop_against_string(self, v:UOp, s:str):
   sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL)}
diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py
index 93686866b2..d51f29c57a 100644
--- a/tinygrad/codegen/__init__.py
+++ b/tinygrad/codegen/__init__.py
@@ -1,6 +1,6 @@
 from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
 from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
-from tinygrad.uop.spec import type_verify
+from tinygrad.uop.spec import type_verify, program_spec
 from tinygrad.renderer import Renderer
 
 # import all pattern matchers here
@@ -98,5 +98,5 @@ def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
   """
   lst = linearize(full_rewrite_to_sink(sink, ren, optimize=sink.tag is None))
-  if __debug__: type_verify(lst)
+  if __debug__: type_verify(lst, program_spec)
   return lst
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 5ffa6be36a..95ed3397e7 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -11,7 +11,7 @@ from tinygrad.helpers import suppress_finalizing
 from tinygrad.gradient import compute_gradient
 from tinygrad.uop.mathtraits import MathTrait
 from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, srender
-from tinygrad.uop.spec import tensor_uop_spec, type_verify
+from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.memory import memory_planner
@@ -229,7 +229,7 @@ class Tensor(MathTrait):
     big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
 
     # verify Tensors match the spec
-    if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec)
+    if __debug__: type_verify(list(big_sink.toposort()), tensor_spec)
 
     if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
       _apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index 04919ad768..71a1379768 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -1,60 +1,45 @@
-from typing import cast, Callable
-from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite, AxisType
+from typing import cast
+from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType
 from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid
-from tinygrad.helpers import all_same, prod, DEBUG, IGNORE_OOB, Context, cpu_profile
-try:
-  import z3
-  # older versions of z3 dont have some operators like & overloaded
-  if z3.get_version() < (4, 12, 4, 0): raise ImportError
+from tinygrad.helpers import DEBUG, Context
+from tinygrad.uop.validate import validate_index
 
-  # IDIV is truncated division but z3 does euclidian division (floor if b>0 ceil otherwise); mod by power of two sometimes uses Ops.AND
-  def z3_cdiv(a, b):return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
-  z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()),
-    Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: lambda a,b,c: z3.If(a,b,c),
-    Ops.MAX: lambda a,b: z3.If(a<b, b, a), Ops.TRUNC: lambda a: z3.If(a >= 0, z3.ToInt(a), -z3.ToInt(-a))}
-  def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef:
-    s = z3.Int(name, ctx=solver.ctx)
-    solver.add(vmin <= s, s <= vmax)
-    return s
+# four specs:
+# shared_spec -- usable anywhere
+# tensor_spec -- usable in tensor graph
+# program_spec -- usable in linearized program
+# full_spec -- all uops ever created
 
-  # ctx is (solver, load_number_dict)
-  # each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different
-  # contexts can have the same hash but error on comparison
-  z3_renderer = PatternMatcher([
-    (UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))),
-    (UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))),
-    (UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))),
-    # loaded bools become a z3 int with min max of 0-1
-    (UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x,ctx:
-      UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))).cast(x.dtype)),
-    (UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"),
-      lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))),
-    # z3 can cast from bool to int automatically
-    (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
-    (UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))),
-    # if the source of the cast is not a noop it means that it is a float and so we create a new variable
-    (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx:
-      UOp(Ops.NOOP, arg=(ctx[0], create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))),
-    (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x,ctx:
-      UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
-    (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3_alu[x.op](*(s.arg[1] for s in x.src))))),
-    # A comparison between floats introduces a new bool variable
-    (UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx:
-      UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"float_cmp{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
-  ])
+# *** these uops work anywhere ***
 
-  def uops_to_z3(solver, *uops: UOp) -> 'list[z3.ExprRef]':
-    with Context(TRACK_MATCH_STATS=0, SPEC=0): # cant pickle z3 objects, and these UOps don't follow spec
-      return [s.arg[1] for s in graph_rewrite(uops[0].sink(*uops[1:]), z3_renderer, ctx=(solver, {})).src]
+shared_spec = PatternMatcher([
+  (UPat(Ops.SINK, dtypes.void), lambda: True), # NOTE: for testing, we let sinks be anything
 
-  z3_imported = True
-except (ImportError, AttributeError): z3_imported = False
+  # CONST/DEFINE_VAR are everywhere
+  (UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
+  (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
 
-buffer_spec = PatternMatcher([
+  # ALUs: most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
+  (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
+  (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
+  # and SHL/SHR, the shift distance can be an int
+  (UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
+  (UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
+  (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
+
+  # CAST
+  (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
+
+  # RANGE can be in the big graph now
+  (UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
+    rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
+    all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
+])
+
+# ***** UOp spec in the Tensor graph *****
+
+tensor_spec = PatternMatcher([
+  # buffer spec
   (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
   (UPat(Ops.DEVICE, dtypes.void, (), name="d"),
    lambda d: isinstance(d.arg, str) or (isinstance(d.arg, tuple) and all(isinstance(s, str) for s in d.arg))),
@@ -63,9 +48,7 @@ buffer_spec = PatternMatcher([
   (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
    lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
   (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.MSTACK, src=UPat(Ops.BUFFER)),)), lambda: True),
-])
-assign_spec = PatternMatcher([
+
   # KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
   (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
@@ -77,11 +60,7 @@ assign_spec = PatternMatcher([
 
   # MSTACK combines buffers into multi
   (UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
-])
-
-# *** this is the spec of a Tensor in UOp ***
-
-tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
   (UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
   (UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
   (UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
@@ -109,127 +88,69 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
   (UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
   (UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)),
 
+  # REDUCE_AXIS is the reduce in the tensor graph
+  (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
+
   # REDUCE with an outerworld range
   (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
 
   # AFTER if things were kernelized
-  (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True)
-])
+  (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
+])+shared_spec
 
-# ***** uop type spec *****
+# ***** UOp spec in linearized programs *****
 
-def validate_index(idx:UOp, gate:UOp|None=None):
-  if gate is None: gate = UOp.const(dtypes.bool, True)
-  # TODO: check for overflow
-  if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True
-  # We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
-  if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
-  mask = idx.src[2]&gate if len(idx.src) == 3 else gate
-  if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"")
-  solver = z3.Solver(ctx=z3.Context())
-  z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], mask)
-  solver.add(z3_mask)
-  with cpu_profile("validate index with z3", "TINY"):
-    if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
-      print(f"idx={idx.src[1].render(simplify=False)}")
-      print(f"mask & gate={mask.render(simplify=False)}")
-      print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
-      return False
-  return True
-
-def validate_store(idx:UOp, val:UOp, gate:UOp|None=None):
-  if gate is None: gate = UOp.const(dtypes.bool, True)
-  if gate.op is Ops.IF: gate = gate.src[0]
-  # we need to find the implicit gates, inverse of delete_redundant_gates
-  for u in val.toposort():
-    if u.op is Ops.IF: gate &= u.src[0]
-  return validate_index(idx, gate)
-
-index_pat = UPat(Ops.INDEX, name="idx").or_casted()
-
-# this is the matcher for the final rendered UOps
-# matcher functions returns True or False (or None to not match)
-spec = PatternMatcher([
+program_spec = PatternMatcher([
+  # DEFINEs
   (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL),
   (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL),
   (UPat(Ops.DEFINE_REG, src=()), lambda: True),
-  (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
-
-  (UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
-    rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
-    all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
-  (UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
-
-  (UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
 
   # allow AFTER on buffers
   (UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True),
 
-  # **** new style load/store ****
+  # INDEX is used in new style load/store
+  (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
+  (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True),
+
+  # LOAD (idx, alt_value) / LOAD(idx) / STORE(idx, val)
+  (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, name="idx").or_casted(), UPat((Ops.VECTORIZE, Ops.VCONST, Ops.CONST)))), validate_index),
+  (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, name="idx").or_casted(), )), validate_index),
+  (UPat(Ops.STORE, src=(UPat(Ops.INDEX, name="idx").or_casted(), UPat())), validate_index),
+
+  # RANGE/SPECIAL define loops, END closes them
+  (UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
+  (UPat(Ops.END, src=(UPat(Ops.RANGE), UPat()), allow_any_len=True, arg=1, dtype=dtypes.void), lambda: True),
 
   # make sure all index dtypes have been lowered
   (UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
   (UPat(Ops.CONST, arg=Invalid), lambda: False),
   (UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.src)),
 
-  # INDEX is used in new style load/store
-  # INDEX takes a <buf, idx, gate?>
-  (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True),
-  (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
-
-  # LOAD takes a <bufidx, if?>
-  (UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond")), allow_any_len=True), lambda idx,cond: validate_index(idx,cond.src[0])),
-  (UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True), validate_index),
-
-  # STORE takes a <bufidx, val>
-  (UPat(Ops.STORE, src=(index_pat, UPat(name="val")), allow_any_len=True), validate_store),
-
-  # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
-  (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
-  (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
-  # and SHL/SHR, the shift distance can be an int
-  (UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
-  (UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
-  (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
-
-  (UPat(Ops.END, dtype=dtypes.void), lambda: True),
-
   # WMMA has a <a, b, acc>
   (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
-  (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
-  (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
 
-  # if has a <gate, barrier?>
-  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),), allow_any_len=True), lambda: True),
-  (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),), allow_any_len=True), lambda: True),
+  # if has a <condition, index>
+  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),
+  (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
 
-  (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
-  (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
+  # VECTORIZE/GEP
   (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
-  (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
+  (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
+
+  # BARRIER
   (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
   (UPat(Ops.BARRIER, dtypes.void), lambda: True), # BARRIERs can also happen at the end of loops
 
-  # NOTE: for testing, we let sinks be anything
-  #(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
-  (UPat(Ops.SINK, dtypes.void), lambda: True),
-
   (UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
-])
-
-# *** this is the UOp AST spec ***
-
-ast_spec = PatternMatcher([
-  # all parent UOps must have the same shape
-  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
-])
+])+shared_spec
 
 # *** this spec should match all UOps ever created ***
 
 full_spec = PatternMatcher([
+  # any END
+  (UPat(Ops.END), lambda: True),
+
   # SENTINEL should never be in the graph
   (UPat(Ops.SENTINEL), lambda: False),
 
@@ -254,6 +175,8 @@ full_spec = PatternMatcher([
   (UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat(GroupOp.Movement))), lambda: True),
 
   # expander: unroll/contract/gep/ptrcat/cat
+  #(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
+  #(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
   (UPat((Ops.UNROLL, Ops.CONTRACT), src=(UPat(),)), lambda: True),
   # GEP multi is supported here
   (UPat(Ops.GEP, name="gep"), lambda gep: gep.dtype is dtypes.void or gep.dtype.vcount == len(gep.arg)),
@@ -281,12 +204,11 @@ full_spec = PatternMatcher([
   (UPat(Ops.RESHAPE, src=(UPat(Ops.STORE),)), lambda: True),
   # allow any AFTER
   (UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True),
-])+tensor_uop_spec+spec
+])+tensor_spec+program_spec
 
 # ***** uop helpers *****
 
-def type_verify(uops:list[UOp], extra_spec:PatternMatcher|None=None):
-  check_spec = (extra_spec+spec) if extra_spec is not None else spec
+def type_verify(uops:list[UOp], check_spec:PatternMatcher):
   for i,u in enumerate(uops):
     with Context(TRACK_MATCH_STATS=0): ret = check_spec.rewrite(u)
     if cast(bool|None, ret) is not True:
diff --git a/tinygrad/uop/validate.py b/tinygrad/uop/validate.py
new file mode 100644
index 0000000000..63ab0dfe8a
--- /dev/null
+++ b/tinygrad/uop/validate.py
@@ -0,0 +1,79 @@
+from typing import Callable
+from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, python_alu, graph_rewrite
+from tinygrad.dtype import ImageDType, dtypes
+from tinygrad.helpers import IGNORE_OOB, Context, cpu_profile
+
+try:
+  import z3
+  # older versions of z3 don't have some operators like & overloaded
+  if z3.get_version() < (4, 12, 4, 0): raise ImportError
+
+  # IDIV is truncated division but z3 does euclidean division (floor if b>0 ceil otherwise); mod by power of two sometimes uses Ops.AND
+  def z3_cdiv(a, b):return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
+  z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()),
+    Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: lambda a,b,c: z3.If(a,b,c),
+    Ops.MAX: lambda a,b: z3.If(a<b, b, a), Ops.TRUNC: lambda a: z3.If(a >= 0, z3.ToInt(a), -z3.ToInt(-a))}
+  def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef:
+    s = z3.Int(name, ctx=solver.ctx)
+    solver.add(vmin <= s, s <= vmax)
+    return s
+
+  # ctx is (solver, load_number_dict)
+  # each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different
+  # contexts can have the same hash but error on comparison
+  z3_renderer = PatternMatcher([
+    (UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))),
+    (UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))),
+    (UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))),
+    # loaded bools become a z3 int with min max of 0-1
+    (UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x,ctx:
+      UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))).cast(x.dtype)),
+    (UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"),
+      lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))),
+    # z3 can cast from bool to int automatically
+    (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
+    (UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))),
+    # if the source of the cast is not a noop it means that it is a float and so we create a new variable
+    (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx:
+      UOp(Ops.NOOP, arg=(ctx[0], create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))),
+    (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x,ctx:
+      UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
+    (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3_alu[x.op](*(s.arg[1] for s in x.src))))),
+    # A comparison between floats introduces a new bool variable
+    (UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx:
+      UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"float_cmp{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
+  ])
+
+  def uops_to_z3(solver, *uops: UOp) -> 'list[z3.ExprRef]':
+    with Context(TRACK_MATCH_STATS=0, SPEC=0): # can't pickle z3 objects, and these UOps don't follow spec
+      return [s.arg[1] for s in graph_rewrite(uops[0].sink(*uops[1:]), z3_renderer, ctx=(solver, {})).src]
+
+  z3_imported = True
+except (ImportError, AttributeError): z3_imported = False
+
+def validate_index(idx:UOp, gate:UOp|None=None):
+  if gate is None: gate = UOp.const(dtypes.bool, True)
+  # TODO: check for overflow
+  if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True
+  # We can use UOp min/max to do a faster check, but it can give false positives since it's not an exact bound and doesn't consider the mask
+  if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
+  mask = idx.src[2]&gate if len(idx.src) == 3 else gate
+  if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"")
+  solver = z3.Solver(ctx=z3.Context())
+  z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], mask)
+  solver.add(z3_mask)
+  with cpu_profile("validate index with z3", "TINY"):
+    if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
+      print(f"idx={idx.src[1].render(simplify=False)}")
+      print(f"mask & gate={mask.render(simplify=False)}")
+      print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
+      return False
+  return True
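
A quick look at the new API shape (a minimal sketch, not part of the patch; it assumes this change is applied, and the failure behavior of type_verify -- print the offending uop, then raise -- comes from its unchanged tail, which this diff does not show):

    # sketch: type_verify now takes the spec to check against explicitly
    from tinygrad.dtype import dtypes
    from tinygrad.uop.ops import Ops, UOp
    from tinygrad.uop.spec import type_verify, shared_spec

    # CONSTs and an ADD with matching dtypes are legal anywhere, so they satisfy shared_spec
    a, b = UOp.const(dtypes.int, 2), UOp.const(dtypes.int, 3)
    type_verify([a, b, a+b], shared_spec)

    # a CONST whose arg type doesn't match its dtype fails the shared CONST rule
    bad = UOp(Ops.CONST, dtypes.int, arg=2.5)
    #type_verify([bad], shared_spec)  # would print the uop and raise

    # the z3 helpers moved: they now import from tinygrad.uop.validate (needs z3 installed)
    from tinygrad.uop.validate import uops_to_z3, z3_cdiv  # noqa: F401

The same pattern holds for the other callers touched here: tensor.py verifies against tensor_spec, and full_rewrite verifies its linearized uops against program_spec.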