mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-07 03:00:26 -04:00
schedule_uop api refactor [pr] (#8259)
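In sketch form, this refactor folds full_ast_rewrite and the ScheduleItem assembly its caller repeated into a single schedule_uop entry point that returns the finished ScheduleItem. A minimal before/after view of the call site, using only names that appear in the diff below (stores and ctx are the caller's locals; this is an illustration of the API change, not the full scheduler):

    # before: rewrite the sink, then hand-build the ScheduleItem from the returned context
    ast, ast_ctx = full_ast_rewrite(UOp.sink(*stores), ctx)
    si = ScheduleItem(ast, tuple(u.buffer for u in ast_ctx.bufs if u.size != 0), tuple(ast_ctx.metadata),
                      frozenset(ubuf for ubuf,ops in ast_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops)))

    # after: one call returns the finished ScheduleItem
    si = schedule_uop(UOp.sink(*stores), ctx)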
@@ -1,7 +1,7 @@
 import sys, atexit, functools, pickle
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
-from typing import FrozenSet, Set, Tuple, List, Dict, Optional, DefaultDict
+from typing import Set, Tuple, List, Dict, Optional, DefaultDict
 from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
 from tinygrad.ops import identity_element, buffers, exec_alu
 from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
@@ -23,7 +23,7 @@ class ScheduleItem:
   ast: UOp
   bufs: Tuple[Buffer, ...]
   metadata: Tuple[Metadata, ...]
-  assign_preloads: FrozenSet[UOp]
+  assign_preloads: Tuple[UOp, ...]
   @property
   def outputs(self) -> Tuple[Buffer, ...]:
     """Read/write or write only buffers in the schedule."""
@@ -171,7 +171,7 @@ add_assign_adjacents = PatternMatcher([(UPat.load(UPat.var("b"), UPat(), name="x
 # late folding for multi output kernels
 multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.sinked.get(b)),])
 
-def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemContext]:
+def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
   # create the ast context
   si_ctx = ScheduleItemContext(ctx.tensor_uops, ctx.ops_metadata, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src})
   create_ctx = add_metadata if len(si_ctx.assigns) == 0 else add_metadata+add_assign_adjacents
@@ -190,9 +190,14 @@ def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemCon
                        and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in ops):
       raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                          +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+  # can only schedule once
+  for buf_uop in si_ctx.sinked:
+    for luop in si_ctx.tensor_uops[buf_uop]: luop.become(buf_uop.view(unwrap(luop.st)))
+  # capture process replay
   if getenv("RUN_PROCESS_REPLAY"):
     PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, si_ctx.assigns, {k:v.value for k,v in ContextVar._cache.items()}, sink))
-  return sink, si_ctx
+  return ScheduleItem(sink, tuple(u.buffer for u in si_ctx.bufs if u.size != 0), tuple(si_ctx.metadata),
+                      tuple(ubuf for ubuf,ops in si_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops)))
 
 PROCESS_REPLAY_CAPTURE: Dict[str, bytes] = {}
 if getenv("RUN_PROCESS_REPLAY"):
@@ -507,11 +512,7 @@ def create_schedule_with_vars(outs:List[UOp]) -> Tuple[List[ScheduleItem], Dict[
   prescheduled: List[ScheduleItem] = []
   for store_uops in store_groups:
     if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) != 0:
-      ast, ast_ctx = full_ast_rewrite(UOp.sink(*stores), ctx)
-      prescheduled.append(ScheduleItem(ast, tuple(u.buffer for u in ast_ctx.bufs if u.size != 0), tuple(ast_ctx.metadata),
-                                       frozenset(ubuf for ubuf,ops in ast_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops))))
-      for buf_uop in ast_ctx.sinked:
-        for luop in ast_ctx.tensor_uops[buf_uop]: luop.become(buf_uop.view(unwrap(luop.st)))
+      prescheduled.append(schedule_uop(UOp.sink(*stores), ctx))
   # do BFS
   schedule_targets = {out:si for si in prescheduled for out in si.outputs}
   graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
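One observable consequence of the ScheduleItem change above: assign_preloads is now a Tuple[UOp, ...] built from si_ctx.assign_adj (a dict, so insertion-ordered) rather than a FrozenSet[UOp]. Membership tests via `in` behave the same, but iteration order becomes deterministic. A standalone sketch with plain ints standing in for UOps (buf_a, buf_b, buf_c are hypothetical stand-ins):

    buf_a, buf_b, buf_c = 3, 1, 2
    old_style = frozenset([buf_a, buf_b, buf_c])      # unordered, deduplicated
    new_style = (buf_a, buf_b, buf_c)                 # preserves collection order
    assert buf_c in old_style and buf_c in new_style  # `in` checks are unchanged
    assert list(new_style) == [3, 1, 2]               # order now follows assign_adj insertion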