schedule_uop api refactor [pr] (#8259)

This commit is contained in:
qazal
2024-12-15 13:50:00 +02:00
committed by GitHub
parent ef1346ab39
commit 58b224a40f

View File

@@ -1,7 +1,7 @@
import sys, atexit, functools, pickle
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import FrozenSet, Set, Tuple, List, Dict, Optional, DefaultDict
from typing import Set, Tuple, List, Dict, Optional, DefaultDict
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
from tinygrad.ops import identity_element, buffers, exec_alu
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
@@ -23,7 +23,7 @@ class ScheduleItem:
ast: UOp
bufs: Tuple[Buffer, ...]
metadata: Tuple[Metadata, ...]
assign_preloads: FrozenSet[UOp]
assign_preloads: Tuple[UOp, ...]
@property
def outputs(self) -> Tuple[Buffer, ...]:
"""Read/write or write only buffers in the schedule."""
@@ -171,7 +171,7 @@ add_assign_adjacents = PatternMatcher([(UPat.load(UPat.var("b"), UPat(), name="x
# late folding for multi output kernels
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.sinked.get(b)),])
def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemContext]:
def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
# create the ast context
si_ctx = ScheduleItemContext(ctx.tensor_uops, ctx.ops_metadata, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src})
create_ctx = add_metadata if len(si_ctx.assigns) == 0 else add_metadata+add_assign_adjacents
@@ -190,9 +190,14 @@ def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemCon
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in ops):
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
# can only schedule once
for buf_uop in si_ctx.sinked:
for luop in si_ctx.tensor_uops[buf_uop]: luop.become(buf_uop.view(unwrap(luop.st)))
# capture process replay
if getenv("RUN_PROCESS_REPLAY"):
PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, si_ctx.assigns, {k:v.value for k,v in ContextVar._cache.items()}, sink))
return sink, si_ctx
return ScheduleItem(sink, tuple(u.buffer for u in si_ctx.bufs if u.size != 0), tuple(si_ctx.metadata),
tuple(ubuf for ubuf,ops in si_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops)))
PROCESS_REPLAY_CAPTURE: Dict[str, bytes] = {}
if getenv("RUN_PROCESS_REPLAY"):
@@ -507,11 +512,7 @@ def create_schedule_with_vars(outs:List[UOp]) -> Tuple[List[ScheduleItem], Dict[
prescheduled: List[ScheduleItem] = []
for store_uops in store_groups:
if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) != 0:
ast, ast_ctx = full_ast_rewrite(UOp.sink(*stores), ctx)
prescheduled.append(ScheduleItem(ast, tuple(u.buffer for u in ast_ctx.bufs if u.size != 0), tuple(ast_ctx.metadata),
frozenset(ubuf for ubuf,ops in ast_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops))))
for buf_uop in ast_ctx.sinked:
for luop in ast_ctx.tensor_uops[buf_uop]: luop.become(buf_uop.view(unwrap(luop.st)))
prescheduled.append(schedule_uop(UOp.sink(*stores), ctx))
# do BFS
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)