diff --git a/test/test_schedule.py b/test/test_schedule.py index 68cd7180a0..7e3801e61f 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -10,9 +10,8 @@ from hypothesis import assume, given, settings, strategies as strat from tinygrad import nn, dtypes, Device, Tensor, Variable from tinygrad.device import is_dtype_supported from tinygrad.dtype import DType, ImageDType -from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat +from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, Kernel from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp -from tinygrad.schedule.rangeify import Kernel from tinygrad.engine.realize import CompiledRunner, run_schedule class KernelCountException(Exception): pass diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 27891df25a..a1c57ac8ed 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,7 +1,7 @@ import time from typing import cast from collections import deque -from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel +from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, gate_kernel_sink from tinygrad.uop.spec import type_verify, tensor_spec from tinygrad.device import Buffer, MultiBuffer from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata @@ -22,14 +22,14 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]: # build kernel dependency graph: edges from producer kernel to consumer kernels children: dict[UOp, list[UOp]] = {} in_degree: dict[UOp, int] = {} - for u in sched_sink.toposort(): + for u in sched_sink.toposort(gate_kernel_sink): if u.op is Ops.RANGE: in_degree.setdefault(u, 0) if u.op is not Ops.AFTER: continue if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph assert k.op in {Ops.KERNEL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}" in_degree.setdefault(k, 0) if k.op is Ops.END: assert k.src[0].op is Ops.KERNEL, f"END src[0] should be KERNEL, not {k.src[0].op}" - for s in k.src[0].src if k.op is Ops.END else k.src[1:]: + for s in k.src[0].src[1:] if k.op is Ops.END else k.src[1:]: match (s := _unwrap_src(s)).op: case Ops.AFTER: children.setdefault(s.src[1], []).append(k) @@ -60,7 +60,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]: #ast = (kernel:=cast(Kernel, k.arg)).ast ast = k.src[0] buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND) - bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE) + bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE) sched_item[k] = (ast, buf_uops, k.arg.metadata, bound_ranges) schedule.append(k) if rk.op is Ops.END: schedule.append(rk) diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index d8d4fcfbc9..2f5b424a45 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -3,9 +3,9 @@ import functools, itertools from dataclasses import dataclass, field from tinygrad.dtype import dtypes, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches -from tinygrad.uop.ops import consumer_map_from_toposort +from tinygrad.uop.ops import consumer_map_from_toposort, KernelInfo, BottomUpGate, gate_kernel_sink from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses -from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored +from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, panic ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM, @@ -19,9 +19,12 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None: def realize_assign(ctx:dict[UOp, None], a:UOp) -> None: if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None - ctx[a] = None + # if it's a kernel, we don't realize it + if a.src[1].op is not Ops.KERNEL: ctx[a] = None pm_generate_realize_map = PatternMatcher([ + # if it's a Kernel, stop + (UPat(Ops.SINK, name="sink"), lambda sink: panic(BottomUpGate()) if isinstance(sink.arg, KernelInfo) else None), # always realize SINK src (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)), # always realize COPY/BUFFER_VIEW/CONTIGUOUS/STORE/ENCDEC @@ -161,11 +164,11 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: rctx = IndexingContext() # get ops to realize - graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, name="get realize") + graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, bottom_up=True, name="get realize") # get the consumer map with cpu_profile("consumer map in rangeify", "TINY"): - consumer_map = consumer_map_from_toposort(tsink_toposort:=tsink.toposort()) + consumer_map = consumer_map_from_toposort(tsink_toposort:=tsink.toposort(gate_kernel_sink)) # explicit rangeify ending_ranges: dict[UOp, list[UOp]] = {} diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 4c77d4c135..ea3eb2e661 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace import itertools from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo -from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str +from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str from tinygrad.uop.symbolic import symbolic from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY from tinygrad.helpers import PCONTIG, partition, get_single_element, panic @@ -66,9 +66,12 @@ mop_cleanup = PatternMatcher([ def resolve_custom_kernel(ck:UOp) -> UOp: placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)] - return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders))) + return ck.arg.fxn(*placeholders).call(*ck.src) -def resolve_call(c:UOp) -> UOp: +def resolve_call(c:UOp) -> UOp|None: + # don't resolve real kernel calls, sink or program + if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None + if c.src[0].op is Ops.PROGRAM: return None params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg) args = c.src[1:] # TODO: this check belongs in spec, not here diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 325689e9fd..d1e945906a 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -67,7 +67,8 @@ def consumer_map_from_toposort(lst:Iterable[UOp]): ret: dict[UOp, dict[UOp, None]] = {} for u in lst: ret[u] = {} - for s in u.src: ret[s][u] = None + for s in u.src: + if s in ret: ret[s][u] = None return ret def pretty_print(x:UOp, cache=None, d=0)->str: @@ -236,7 +237,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): return self.src[0]._shape # ops with custom handling - case Ops.KERNEL: return self.arg.ast._shape + #case Ops.KERNEL: return self.arg.ast._shape # TODO: disallow shape changing bitcast case Ops.BITCAST: @@ -836,6 +837,8 @@ class KernelInfo: @property def function_name(self): return to_function_name(self.name) +def gate_kernel_sink(x:UOp) -> bool: return not (x.op is Ops.SINK and isinstance(x.arg, KernelInfo)) + @dataclass(frozen=True) class CustomKernel: fxn: Callable @@ -1448,10 +1451,10 @@ def pyrender(ast:UOp) -> str: if op_depth > 100: to_render.add(u) depth[u] = 0 if u in to_render else op_depth # do the rendering - if u.op is Ops.KERNEL: - if u.arg.ast not in kernels: - kernels[u.arg.ast] = (f"k{len(kernels)}", f"def k{len(kernels)}():\n " + pyrender(u.arg.ast).replace('\n', '\n ') + "\n return ast\n\n") - r[u.arg.ast] = kernels[u.arg.ast][0] + #if u.op is Ops.KERNEL: + # if u.arg.ast not in kernels: + # kernels[u.arg.ast] = (f"k{len(kernels)}", f"def k{len(kernels)}():\n " + pyrender(u.arg.ast).replace('\n', '\n ') + "\n return ast\n\n") + # r[u.arg.ast] = kernels[u.arg.ast][0] ren = cast(str, pm_pyrender.rewrite(u, ctx=r)) assert isinstance(ren, str) if u.tag is not None: ren += f".rtag({repr(u.tag)})" diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index ff8d981e0d..f68fe5d2b9 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -45,7 +45,7 @@ from tinygrad.dtype import dtypes uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B", Ops.PARAM:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", - Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.CUSTOM_KERNEL: "#3ebf55", + Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.CUSTOM_KERNEL: "#3ebf55", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6", Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F",