From b4f6a2c7a384dc63b9a5e7e866bce5aa42bd13a6 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 25 Oct 2025 11:49:20 +0800 Subject: [PATCH] add kernel spec (#12911) * add kernel spec * fix kernel spec --- test/test_uop_graph.py | 6 ++--- test/unit/test_graph_rewrite.py | 3 ++- test/unit/test_simplify_valid_idx.py | 7 ++--- tinygrad/codegen/__init__.py | 4 ++- tinygrad/uop/spec.py | 39 +++++++++++++++++++--------- 5 files changed, 39 insertions(+), 20 deletions(-) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 1f5a849fcf..d85e25ff30 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -574,7 +574,7 @@ class TestUOpGraph(unittest.TestCase): def test_in_out_bounds_access_with_mask(self): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) - gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0") + gidx0 = UOp.range(42, 0, AxisType.GLOBAL) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5=0)&(ld0<32)),)) to_uops_list([ld1]) @@ -834,7 +834,7 @@ class TestIFUOps(unittest.TestCase): valid = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), "gidx0")<1 lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0") gate = valid&(lidx.ne(2)) - st = UOp(Ops.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42))) + st = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx), UOp.const(dtypes.float, 42))) barrier = UOp(Ops.BARRIER, dtypes.void, (st,)) lbufs = [UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, i)), barrier)) for i in range(4)] stores = [UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, i), gate), lbufs[i])) for i in range(4)] diff --git a/test/unit/test_graph_rewrite.py b/test/unit/test_graph_rewrite.py index ea9c271aec..bff8c7d055 100644 --- a/test/unit/test_graph_rewrite.py +++ b/test/unit/test_graph_rewrite.py @@ -1,11 +1,12 @@ import unittest, math from tinygrad import dtypes -from tinygrad.helpers import all_same +from tinygrad.helpers import all_same, Context from tinygrad.uop.ops import GroupOp, UOp, Ops, exec_alu, PatternMatcher, TrackedPatternMatcher, UPat from tinygrad.codegen import full_rewrite_to_sink from hypothesis import given, strategies as strat # Helper function to apply the graph rewrite +@Context(SPEC=0) def apply_rewrite(expr): return full_rewrite_to_sink(expr.sink()).src[0] diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index dfaee9e58d..ccea0deb21 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -47,7 +47,7 @@ class TestHelpers(unittest.TestCase): class TestValidIdxSimplification(unittest.TestCase): def check(self, load, sidx, svalid): - with Context(NOOPT=1): + with Context(NOOPT=1, SPEC=0): load = full_rewrite_to_sink(load.sink()).src[0] idx, valid = load.src[0].src[1], load.src[0].src[2] check_uop_against_string(self, idx, sidx) @@ -213,7 +213,7 @@ class TestValidIdxSimplification(unittest.TestCase): class TestImageSimplification(unittest.TestCase): def check(self, load, svalid, sidx0, sidx1): - with Context(NOOPT=1): + with Context(NOOPT=1, SPEC=0): load = full_rewrite_to_sink(load.sink()).src[0] idx = load.src[0].src[1] self.assertEqual(idx.op, Ops.VECTORIZE) @@ -283,7 +283,8 @@ class TestImageSimplification(unittest.TestCase): # empty -> invalid load = get_load_image_uop(shape, (gidx0<8) & (gidx0<8).ne(True), idx) - load = full_rewrite_to_sink(load.sink()).src[0] + with Context(NOOPT=1, SPEC=0): + load = full_rewrite_to_sink(load.sink()).src[0] self.assertEqual(load.op, Ops.VECTORIZE) self.assertEqual(load.dtype.count, 4) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 99416e44c7..81c6058722 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -1,6 +1,6 @@ from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, SPEC from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype -from tinygrad.uop.spec import type_verify, program_spec +from tinygrad.uop.spec import type_verify, program_spec, kernel_spec from tinygrad.renderer import Renderer # import all pattern matchers here @@ -19,6 +19,8 @@ from tinygrad.codegen.late.control_flow import CFGContext, pm_split_ends, pm_add def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp: if ren is None: ren = Renderer() + if SPEC: type_verify(list(sink.toposort()), kernel_spec) + # first we optimize if optimize: if QUANTIZE and ren.device in {"CPU", "DSP"}: sink = graph_rewrite(sink, pm_quant, name="quantize") diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 6af4aa943d..2045e590a4 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -1,12 +1,13 @@ from typing import cast from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid -from tinygrad.helpers import DEBUG, Context +from tinygrad.helpers import DEBUG, Context, prod from tinygrad.uop.validate import validate_index # four specs: # shared_spec -- usable anywhere # tensor_spec -- usable in tensor graph +# kernel_spec -- usable in kernel passed into codegen # program_spec -- usable in linearized program # full_spec -- all uops ever created @@ -15,6 +16,9 @@ from tinygrad.uop.validate import validate_index shared_spec = PatternMatcher([ (UPat(Ops.SINK, dtypes.void), lambda: True), # NOTE: for testing, we let sinks be anything + # SENTINEL should never be anywhere + (UPat(Ops.SENTINEL), lambda: False), + # CONST/DEFINE_VAR are everywhere (UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))), (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)), @@ -146,15 +150,33 @@ program_spec = PatternMatcher([ (UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True), ])+shared_spec +# ***** UOp spec in kernel graph ***** + +kernel_spec = PatternMatcher([ + # index is allowed here + (UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True), + + # UNROLL/CONTRACT is used here for WMMA + (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), + (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)), + + # END can end multiple axes here + (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), allow_any_len=True, dtype=dtypes.void), lambda: True), + + # bufferize (must be on ranges) + (UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.op in {Ops.RANGE, Ops.CONST} for y in x.src[1:])), + (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])), + + # intermediate index + (UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None), +])+program_spec+shared_spec + # *** this spec should match all UOps ever created *** full_spec = PatternMatcher([ # any END (UPat(Ops.END), lambda: True), - # SENTINEL should never be in the graph - (UPat(Ops.SENTINEL), lambda: False), - # Invalid must have type Index (UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index), # where on index in rhs position is fine @@ -165,19 +187,12 @@ full_spec = PatternMatcher([ # rangeify: buffer view with index or load is okay (UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.INDEX, Ops.LOAD)),)), lambda: True), - # bufferize (must be on ranges) - (UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.op in {Ops.RANGE, Ops.CONST} for y in x.src[1:])), - # intermediate index - (UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None), - (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])), # copy on index (UPat(Ops.COPY, src=(UPat(Ops.INDEX), UPat())), lambda: True), # assign on index. the third op is the shape (UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat())), lambda: True), # expander: unroll/contract/gep/ptrcat/cat - #(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), - #(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)), (UPat((Ops.UNROLL, Ops.CONTRACT), src=(UPat(),)), lambda: True), # GEP multi is supported here (UPat(Ops.GEP, name="gep"), lambda gep: gep.dtype is dtypes.void or gep.dtype.vcount == len(gep.arg)), @@ -211,7 +226,7 @@ full_spec = PatternMatcher([ (UPat(Ops.RESHAPE, src=(UPat(Ops.STORE),)), lambda: True), # allow any AFTER (UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True), -])+tensor_spec+program_spec +])+tensor_spec+kernel_spec+program_spec+shared_spec # ***** uop helpers *****