add kernel spec (#12911)

* add kernel spec

* fix kernel spec
This commit is contained in:
George Hotz
2025-10-25 11:49:20 +08:00
committed by GitHub
parent 8a941d95a4
commit b4f6a2c7a3
5 changed files with 39 additions and 20 deletions

View File

@@ -574,7 +574,7 @@ class TestUOpGraph(unittest.TestCase):
def test_in_out_bounds_access_with_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0")
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5<gidx0)&(gidx0<16)),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<16),))
to_uops_list([ld0, ld1])
@@ -598,7 +598,7 @@ class TestUOpGraph(unittest.TestCase):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0")
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),)).cast(dtypes.index)
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),))
to_uops_list([ld1])
@@ -834,7 +834,7 @@ class TestIFUOps(unittest.TestCase):
valid = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), "gidx0")<1
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
gate = valid&(lidx.ne(2))
st = UOp(Ops.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42)))
st = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx), UOp.const(dtypes.float, 42)))
barrier = UOp(Ops.BARRIER, dtypes.void, (st,))
lbufs = [UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, i)), barrier)) for i in range(4)]
stores = [UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, i), gate), lbufs[i])) for i in range(4)]

View File

@@ -1,11 +1,12 @@
import unittest, math
from tinygrad import dtypes
from tinygrad.helpers import all_same
from tinygrad.helpers import all_same, Context
from tinygrad.uop.ops import GroupOp, UOp, Ops, exec_alu, PatternMatcher, TrackedPatternMatcher, UPat
from tinygrad.codegen import full_rewrite_to_sink
from hypothesis import given, strategies as strat
# Helper function to apply the graph rewrite
@Context(SPEC=0)
def apply_rewrite(expr):
  """Run the full codegen rewrite on a single expression and return the rewritten expression.

  The expression is wrapped in a sink, rewritten with full_rewrite_to_sink, and the
  sink's first source is returned. The call runs under Context(SPEC=0).
  """
  wrapped = expr.sink()
  rewritten_sink = full_rewrite_to_sink(wrapped)
  return rewritten_sink.src[0]

View File

@@ -47,7 +47,7 @@ class TestHelpers(unittest.TestCase):
class TestValidIdxSimplification(unittest.TestCase):
def check(self, load, sidx, svalid):
with Context(NOOPT=1):
with Context(NOOPT=1, SPEC=0):
load = full_rewrite_to_sink(load.sink()).src[0]
idx, valid = load.src[0].src[1], load.src[0].src[2]
check_uop_against_string(self, idx, sidx)
@@ -213,7 +213,7 @@ class TestValidIdxSimplification(unittest.TestCase):
class TestImageSimplification(unittest.TestCase):
def check(self, load, svalid, sidx0, sidx1):
with Context(NOOPT=1):
with Context(NOOPT=1, SPEC=0):
load = full_rewrite_to_sink(load.sink()).src[0]
idx = load.src[0].src[1]
self.assertEqual(idx.op, Ops.VECTORIZE)
@@ -283,7 +283,8 @@ class TestImageSimplification(unittest.TestCase):
# empty -> invalid
load = get_load_image_uop(shape, (gidx0<8) & (gidx0<8).ne(True), idx)
load = full_rewrite_to_sink(load.sink()).src[0]
with Context(NOOPT=1, SPEC=0):
load = full_rewrite_to_sink(load.sink()).src[0]
self.assertEqual(load.op, Ops.VECTORIZE)
self.assertEqual(load.dtype.count, 4)

View File

@@ -1,6 +1,6 @@
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, SPEC
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
from tinygrad.uop.spec import type_verify, program_spec
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
from tinygrad.renderer import Renderer
# import all pattern matchers here
@@ -19,6 +19,8 @@ from tinygrad.codegen.late.control_flow import CFGContext, pm_split_ends, pm_add
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
if ren is None: ren = Renderer()
if SPEC: type_verify(list(sink.toposort()), kernel_spec)
# first we optimize
if optimize:
if QUANTIZE and ren.device in {"CPU", "DSP"}: sink = graph_rewrite(sink, pm_quant, name="quantize")

View File

@@ -1,12 +1,13 @@
from typing import cast
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid
from tinygrad.helpers import DEBUG, Context
from tinygrad.helpers import DEBUG, Context, prod
from tinygrad.uop.validate import validate_index
# five specs:
# shared_spec -- usable anywhere
# tensor_spec -- usable in tensor graph
# kernel_spec -- usable in kernel passed into codegen
# program_spec -- usable in linearized program
# full_spec -- all uops ever created
@@ -15,6 +16,9 @@ from tinygrad.uop.validate import validate_index
shared_spec = PatternMatcher([
(UPat(Ops.SINK, dtypes.void), lambda: True), # NOTE: for testing, we let sinks be anything
# SENTINEL should never be anywhere
(UPat(Ops.SENTINEL), lambda: False),
# CONST/DEFINE_VAR are everywhere
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
@@ -146,15 +150,33 @@ program_spec = PatternMatcher([
(UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
])+shared_spec
# ***** UOp spec in kernel graph *****
# Spec for UOps in a kernel graph (the graph passed into codegen).
# The kernel-specific rules below come first; program_spec and shared_spec are appended after them.
kernel_spec = PatternMatcher([
# index is allowed here
# (any elementwise op, CONST, RANGE or DEFINE_VAR with dtype index is accepted with no further checks)
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
# UNROLL/CONTRACT is used here for WMMA
# CONTRACT: the op's own dtype count must equal the product of the second element of every arg entry
# UNROLL: same product, but compared against the dtype count of the first source
# (arg entries are presumably (axis, size) pairs -- TODO confirm)
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# END can end multiple axes here
# (void dtype, second source must be a RANGE, additional sources are permitted via allow_any_len)
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), allow_any_len=True, dtype=dtypes.void), lambda: True),
# bufferize (must be on ranges)
# every source after the first must be a RANGE or a CONST
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.op in {Ops.RANGE, Ops.CONST} for y in x.src[1:])),
# reduce: every source after the first must have dtype index
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
# intermediate index
# every source after the first must have dtype index; on failure this yields None rather than False
# NOTE(review): presumably None lets a later INDEX rule (e.g. in program_spec) still accept the UOp -- confirm against PatternMatcher semantics
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
# NOTE(review): program_spec is itself defined with a trailing +shared_spec, so the extra +shared_spec below looks redundant -- confirm before removing
])+program_spec+shared_spec
# *** this spec should match all UOps ever created ***
full_spec = PatternMatcher([
# any END
(UPat(Ops.END), lambda: True),
# SENTINEL should never be in the graph
(UPat(Ops.SENTINEL), lambda: False),
# Invalid must have type Index
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
# where on index in rhs position is fine
@@ -165,19 +187,12 @@ full_spec = PatternMatcher([
# rangeify: buffer view with index or load is okay
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.INDEX, Ops.LOAD)),)), lambda: True),
# bufferize (must be on ranges)
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.op in {Ops.RANGE, Ops.CONST} for y in x.src[1:])),
# intermediate index
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
# copy on index
(UPat(Ops.COPY, src=(UPat(Ops.INDEX), UPat())), lambda: True),
# assign on index. the third op is the shape
(UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat())), lambda: True),
# expander: unroll/contract/gep/ptrcat/cat
#(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
#(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
(UPat((Ops.UNROLL, Ops.CONTRACT), src=(UPat(),)), lambda: True),
# GEP multi is supported here
(UPat(Ops.GEP, name="gep"), lambda gep: gep.dtype is dtypes.void or gep.dtype.vcount == len(gep.arg)),
@@ -211,7 +226,7 @@ full_spec = PatternMatcher([
(UPat(Ops.RESHAPE, src=(UPat(Ops.STORE),)), lambda: True),
# allow any AFTER
(UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True),
])+tensor_spec+program_spec
])+tensor_spec+kernel_spec+program_spec+shared_spec
# ***** uop helpers *****