kernelize prereqs [pr] (#9811)

* kernelize prereqs [pr]

* work

* tensor maps to assign

* unwrap st

* process replay

* grouper changes

* replay
This commit is contained in:
qazal
2025-04-10 15:22:20 +08:00
committed by GitHub
parent c462162db8
commit fd4f06e623
2 changed files with 18 additions and 13 deletions

View File

@@ -1,10 +1,10 @@
from collections import defaultdict, deque
from dataclasses import dataclass
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve, merge_views
from tinygrad.ops import can_pad, sint
from tinygrad.ops import can_pad, sint, track_rewrites
from tinygrad.codegen.lowerer import get_contraction
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv, pluralize
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
@@ -435,6 +435,11 @@ pm_fuse = PatternMatcher([
(UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))),
])
def get_name(ret:tuple[dict[UOp, UOp], dict[Variable, int]]) -> str:
  """Build the display name for a tracked rewrite from the (becomes_map, var_vals) result tuple."""
  becomes_map, var_vals = ret[0], ret[1]
  # distinct src[1] of every ASSIGN value in the map; this is what gets reported as the kernel count
  kernels = set()
  for u in becomes_map.values():
    if u.op is Ops.ASSIGN: kernels.add(u.src[1])
  name = f"Schedule {pluralize('Kernel', len(kernels))}"
  # the variable suffix is only added when the var_vals dict is non-empty
  if var_vals: name += f" (with_{pluralize('Var', len(var_vals))})"
  return name
@track_rewrites(name_fxn=get_name)
def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
# merge_views + simplify
tensor_map = graph_rewrite_map(big_sink, merge_views+sym+reorder_view+replace_contiguous+pm_fuse, ctx={})
@@ -455,15 +460,9 @@ def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
for k,v in tensor_map.items():
if (kernel:=tensor_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st))
if k is v: continue
if k.op is Ops.ASSIGN:
becomes_map[k] = k.src[0]
continue
op = v.base.op
if op is Ops.BUFFER: becomes_map[k] = v
if op in {Ops.BUFFER, Ops.ASSIGN}: becomes_map[k] = v
if op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
if op is Ops.ASSIGN:
new_buf = v.base.src[0]
becomes_map[k] = new_buf if new_buf.st == v.st else new_buf.view(unwrap(v.st))
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}

View File

@@ -1,9 +1,9 @@
import atexit, pickle
from dataclasses import dataclass
from collections import deque
from tinygrad.ops import UOp, Variable, Ops, buffers, track_rewrites
from tinygrad.ops import UOp, Variable, Ops, buffers
from tinygrad.device import Buffer
from tinygrad.helpers import Metadata, CAPTURE_PROCESS_REPLAY, DEBUG, Context, ContextVar, diskcache_put, pluralize
from tinygrad.helpers import Metadata, CAPTURE_PROCESS_REPLAY, DEBUG, Context, ContextVar, diskcache_put, unwrap
from tinygrad.engine.grouper import get_becomes_map
# **** ScheduleItem return type
@@ -22,8 +22,6 @@ if CAPTURE_PROCESS_REPLAY:
# **** schedule linearizer
@track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
becomes_map, var_vals = get_becomes_map(big_sink)
sched_sink = becomes_map.pop(big_sink)
@@ -59,4 +57,12 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
if CAPTURE_PROCESS_REPLAY:
with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
# map ASSIGN to BUFFER after ScheduleItems are constructed
for k,v in becomes_map.items():
if v.base.op is Ops.ASSIGN:
# if the UOp was already an assign Tensor UOp we just map it to the existing buffer
if k.op is Ops.ASSIGN: becomes_map[k] = k.src[0]
# otherwise we map it to the new buffer, ignoring NOOP ShapeTrackers
else: becomes_map[k] = new_buf if (new_buf:=v.base.src[0]).st == v.st else new_buf.view(unwrap(v.st))
return schedule, var_vals, becomes_map