mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 23:18:04 -05:00)
kernelize prereqs [pr] (#9811)
* kernelize prereqs [pr]
* work
* tensor maps to assign
* unwrap st
* process replay
* grouper changes
* replay
tinygrad/engine/grouper.py

@@ -1,10 +1,10 @@
 from collections import defaultdict, deque
 from dataclasses import dataclass
 from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve, merge_views
-from tinygrad.ops import can_pad, sint
+from tinygrad.ops import can_pad, sint, track_rewrites
 from tinygrad.codegen.lowerer import get_contraction
 from tinygrad.codegen.symbolic import symbolic_simple
-from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv
+from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv, pluralize
 from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP
 from tinygrad.dtype import ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -435,6 +435,11 @@ pm_fuse = PatternMatcher([
   (UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))),
 ])
 
+def get_name(ret:tuple[dict[UOp, UOp], dict[Variable, int]]) -> str:
+  kcount = len({u.src[1] for u in ret[0].values() if u.op is Ops.ASSIGN})
+  return f"Schedule {pluralize('Kernel', kcount)}"+(f" (with_{pluralize('Var', len(ret[1]))})" if ret[1] else "")
+
+@track_rewrites(name_fxn=get_name)
 def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
   # merge_views + simplify
   tensor_map = graph_rewrite_map(big_sink, merge_views+sym+reorder_view+replace_contiguous+pm_fuse, ctx={})
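Note: the new get_name derives the trace label from get_becomes_map's return value, counting kernels as distinct ASSIGN targets and appending bound Variables when present. A minimal sketch of the label format follows; the pluralize stand-in below is an assumption about tinygrad.helpers.pluralize, not the library code:

# toy sketch of the label format; this pluralize is a stand-in assumption,
# not tinygrad.helpers.pluralize
def pluralize(st: str, cnt: int) -> str: return f"{cnt} {st}" + ("" if cnt == 1 else "s")

def name_sketch(kcount: int, varcount: int) -> str:
  # mirrors get_name above, with the kernel/var counts precomputed
  return f"Schedule {pluralize('Kernel', kcount)}" + (f" (with_{pluralize('Var', varcount)})" if varcount else "")

print(name_sketch(1, 0))  # Schedule 1 Kernel
print(name_sketch(3, 2))  # Schedule 3 Kernels (with_2 Vars)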
@@ -455,15 +460,9 @@ def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
   for k,v in tensor_map.items():
     if (kernel:=tensor_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st))
     if k is v: continue
-    if k.op is Ops.ASSIGN:
-      becomes_map[k] = k.src[0]
-      continue
     op = v.base.op
-    if op is Ops.BUFFER: becomes_map[k] = v
+    if op in {Ops.BUFFER, Ops.ASSIGN}: becomes_map[k] = v
     if op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
-    if op is Ops.ASSIGN:
-      new_buf = v.base.src[0]
-      becomes_map[k] = new_buf if new_buf.st == v.st else new_buf.view(unwrap(v.st))
 
   # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
   kernel_assign: dict[UOp, UOp] = {}
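Note: with this hunk, get_becomes_map stops eagerly rewriting ASSIGN tensors to their buffers; ASSIGN entries now survive in becomes_map and are lowered to BUFFER later, in create_schedule_with_vars (see the schedule.py hunk below). A self-contained toy of that two-pass pattern; the Node type here is a hypothetical stand-in, not a tinygrad UOp:

# toy illustration of the deferred rewrite: tensors first map to ASSIGN nodes,
# then a later pass remaps ASSIGN -> BUFFER. Node is a hypothetical stand-in,
# not a tinygrad UOp.
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str                 # "TENSOR", "ASSIGN", "BUFFER", ...
  src: tuple = ()

buf = Node("BUFFER")
assign = Node("ASSIGN", (buf,))
becomes_map = {Node("TENSOR"): assign}  # pass 1 (grouper): tensor -> ASSIGN

# pass 2 (scheduler, after ScheduleItems exist): ASSIGN -> its target buffer
for k, v in becomes_map.items():
  if v.op == "ASSIGN": becomes_map[k] = v.src[0]

assert all(v.op == "BUFFER" for v in becomes_map.values())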
tinygrad/engine/schedule.py

@@ -1,9 +1,9 @@
 import atexit, pickle
 from dataclasses import dataclass
 from collections import deque
-from tinygrad.ops import UOp, Variable, Ops, buffers, track_rewrites
+from tinygrad.ops import UOp, Variable, Ops, buffers
 from tinygrad.device import Buffer
-from tinygrad.helpers import Metadata, CAPTURE_PROCESS_REPLAY, DEBUG, Context, ContextVar, diskcache_put, pluralize
+from tinygrad.helpers import Metadata, CAPTURE_PROCESS_REPLAY, DEBUG, Context, ContextVar, diskcache_put, unwrap
 from tinygrad.engine.grouper import get_becomes_map
 
 # **** ScheduleItem return type
@@ -22,8 +22,6 @@ if CAPTURE_PROCESS_REPLAY:
 
 # **** schedule linearizer
 
-
-@track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
 def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
   becomes_map, var_vals = get_becomes_map(big_sink)
   sched_sink = becomes_map.pop(big_sink)
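Note: the decorator does not disappear here; it moves onto get_becomes_map in grouper.py (above), so rewrite traces are named where the rewrites actually happen. A stand-in sketch of the name_fxn hook pattern; tinygrad's real track_rewrites also captures the rewrite trace for VIZ and does considerably more than this:

# stand-in sketch of the name_fxn hook; only the naming behavior is reproduced
def track_rewrites(name_fxn):
  def deco(fxn):
    def wrapper(*args, **kwargs):
      ret = fxn(*args, **kwargs)
      print(f"[trace] {name_fxn(ret)}")  # the label the tracker would record
      return ret
    return wrapper
  return deco

@track_rewrites(name_fxn=lambda r: f"got {len(r)} items")
def build(): return [1, 2, 3]

build()  # prints: [trace] got 3 items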
@@ -59,4 +57,12 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   if CAPTURE_PROCESS_REPLAY:
     with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
 
+  # map ASSIGN to BUFFER after ScheduleItems are constructed
+  for k,v in becomes_map.items():
+    if v.base.op is Ops.ASSIGN:
+      # if the UOp was already an assign Tensor UOp we just map it to the existing buffer
+      if k.op is Ops.ASSIGN: becomes_map[k] = k.src[0]
+      # otherwise we map it to the new buffer, ignoring NOOP ShapeTrackers
+      else: becomes_map[k] = new_buf if (new_buf:=v.base.src[0]).st == v.st else new_buf.view(unwrap(v.st))
+
   return schedule, var_vals, becomes_map
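Note: end-to-end, the new post-schedule pass means callers that look up an assigned tensor in becomes_map get a BUFFER (possibly behind a view) rather than a raw ASSIGN. A usage sketch against this revision; Tensor.lazydata and UOp.sink are assumptions about the internal API at this commit, so check the checkout before relying on them:

# usage sketch, assuming a tinygrad checkout at this commit
from tinygrad import Tensor
from tinygrad.ops import UOp, Ops
from tinygrad.engine.schedule import create_schedule_with_vars

a = Tensor.ones(4).contiguous()
t = a.assign(a + 1)              # an assign Tensor UOp
big_sink = UOp.sink(t.lazydata)  # sink over the tensor's underlying UOp (assumed attribute)
schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink)
# the loop above should have remapped every ASSIGN to its BUFFER (or a view of it)
assert all(v.base.op is not Ops.ASSIGN for v in becomes_map.values())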