This commit is contained in:
George Hotz
2026-02-06 13:54:15 +08:00
parent 7391ab77f2
commit 85ee86f2fb
6 changed files with 29 additions and 21 deletions

View File

@@ -10,9 +10,8 @@ from hypothesis import assume, given, settings, strategies as strat
from tinygrad import nn, dtypes, Device, Tensor, Variable
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType, ImageDType
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, Kernel
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.schedule.rangeify import Kernel
from tinygrad.engine.realize import CompiledRunner, run_schedule
class KernelCountException(Exception): pass

View File

@@ -1,7 +1,7 @@
import time
from typing import cast
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, gate_kernel_sink
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
@@ -22,14 +22,14 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
# build kernel dependency graph: edges from producer kernel to consumer kernels
children: dict[UOp, list[UOp]] = {}
in_degree: dict[UOp, int] = {}
for u in sched_sink.toposort():
for u in sched_sink.toposort(gate_kernel_sink):
if u.op is Ops.RANGE: in_degree.setdefault(u, 0)
if u.op is not Ops.AFTER: continue
if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph
assert k.op in {Ops.KERNEL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
in_degree.setdefault(k, 0)
if k.op is Ops.END: assert k.src[0].op is Ops.KERNEL, f"END src[0] should be KERNEL, not {k.src[0].op}"
for s in k.src[0].src if k.op is Ops.END else k.src[1:]:
for s in k.src[0].src[1:] if k.op is Ops.END else k.src[1:]:
match (s := _unwrap_src(s)).op:
case Ops.AFTER:
children.setdefault(s.src[1], []).append(k)
@@ -60,7 +60,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
#ast = (kernel:=cast(Kernel, k.arg)).ast
ast = k.src[0]
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
sched_item[k] = (ast, buf_uops, k.arg.metadata, bound_ranges)
schedule.append(k)
if rk.op is Ops.END: schedule.append(rk)

View File

@@ -3,9 +3,9 @@ import functools, itertools
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches
from tinygrad.uop.ops import consumer_map_from_toposort
from tinygrad.uop.ops import consumer_map_from_toposort, KernelInfo, BottomUpGate, gate_kernel_sink
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, panic
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
@@ -19,9 +19,12 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
ctx[a] = None
# if it's a kernel, we don't realize it
if a.src[1].op is not Ops.KERNEL: ctx[a] = None
pm_generate_realize_map = PatternMatcher([
# if it's a Kernel, stop
(UPat(Ops.SINK, name="sink"), lambda sink: panic(BottomUpGate()) if isinstance(sink.arg, KernelInfo) else None),
# always realize SINK src
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
# always realize COPY/BUFFER_VIEW/CONTIGUOUS/STORE/ENCDEC
@@ -161,11 +164,11 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
rctx = IndexingContext()
# get ops to realize
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, name="get realize")
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, bottom_up=True, name="get realize")
# get the consumer map
with cpu_profile("consumer map in rangeify", "TINY"):
consumer_map = consumer_map_from_toposort(tsink_toposort:=tsink.toposort())
consumer_map = consumer_map_from_toposort(tsink_toposort:=tsink.toposort(gate_kernel_sink))
# explicit rangeify
ending_ranges: dict[UOp, list[UOp]] = {}

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
from tinygrad.helpers import PCONTIG, partition, get_single_element, panic
@@ -66,9 +66,12 @@ mop_cleanup = PatternMatcher([
def resolve_custom_kernel(ck:UOp) -> UOp:
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
return ck.arg.fxn(*placeholders).call(*ck.src)
def resolve_call(c:UOp) -> UOp:
def resolve_call(c:UOp) -> UOp|None:
# don't resolve real kernel calls, sink or program
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
if c.src[0].op is Ops.PROGRAM: return None
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
args = c.src[1:]
# TODO: this check belongs in spec, not here

View File

@@ -67,7 +67,8 @@ def consumer_map_from_toposort(lst:Iterable[UOp]):
ret: dict[UOp, dict[UOp, None]] = {}
for u in lst:
ret[u] = {}
for s in u.src: ret[s][u] = None
for s in u.src:
if s in ret: ret[s][u] = None
return ret
def pretty_print(x:UOp, cache=None, d=0)->str:
@@ -236,7 +237,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return self.src[0]._shape
# ops with custom handling
case Ops.KERNEL: return self.arg.ast._shape
#case Ops.KERNEL: return self.arg.ast._shape
# TODO: disallow shape changing bitcast
case Ops.BITCAST:
@@ -836,6 +837,8 @@ class KernelInfo:
@property
def function_name(self): return to_function_name(self.name)
def gate_kernel_sink(x:UOp) -> bool: return not (x.op is Ops.SINK and isinstance(x.arg, KernelInfo))
@dataclass(frozen=True)
class CustomKernel:
fxn: Callable
@@ -1448,10 +1451,10 @@ def pyrender(ast:UOp) -> str:
if op_depth > 100: to_render.add(u)
depth[u] = 0 if u in to_render else op_depth
# do the rendering
if u.op is Ops.KERNEL:
if u.arg.ast not in kernels:
kernels[u.arg.ast] = (f"k{len(kernels)}", f"def k{len(kernels)}():\n " + pyrender(u.arg.ast).replace('\n', '\n ') + "\n return ast\n\n")
r[u.arg.ast] = kernels[u.arg.ast][0]
#if u.op is Ops.KERNEL:
# if u.arg.ast not in kernels:
# kernels[u.arg.ast] = (f"k{len(kernels)}", f"def k{len(kernels)}():\n " + pyrender(u.arg.ast).replace('\n', '\n ') + "\n return ast\n\n")
# r[u.arg.ast] = kernels[u.arg.ast][0]
ren = cast(str, pm_pyrender.rewrite(u, ctx=r))
assert isinstance(ren, str)
if u.tag is not None: ren += f".rtag({repr(u.tag)})"

View File

@@ -45,7 +45,7 @@ from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
Ops.PARAM:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.CUSTOM_KERNEL: "#3ebf55",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.CUSTOM_KERNEL: "#3ebf55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F",