faster assign view check [pr] (#7781)

This commit is contained in:
qazal
2024-11-19 13:42:51 +02:00
committed by GitHub
parent 3daa376107
commit 8360bbd88d

View File

@@ -1,7 +1,7 @@
import sys, atexit, functools, itertools
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast
from typing import FrozenSet, Set, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
@@ -23,7 +23,7 @@ class ScheduleItem:
ast: UOp
bufs: Tuple[Buffer, ...]
metadata: Tuple[Metadata, ...]
assign_preloads: Tuple[UOp, ...]
assign_preloads: FrozenSet[UOp]
@property
def outputs(self) -> Tuple[Buffer, ...]:
"""Read/write or write only buffers in the schedule."""
@@ -186,7 +186,7 @@ def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
if b in ctx.assigned: ctx.assign_preloads.append(b)
if b in ctx.assigned: ctx.assign_preloads.append(x)
return x.replace(op=Ops.LOAD)
to_si = PatternMatcher([
@@ -220,14 +220,13 @@ def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemCon
raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.")
# do movementops
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if len(assign_targets:=[x.buf_uop for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.buf_uop in assign_targets):
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
# convert to AST
sink = graph_rewrite(graph_rewrite(sink, to_si, si_ctx), append_bufs, si_ctx)
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in si_ctx.assign_preloads if si_ctx.sinked.get(x.buf_uop) is not None):
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append(((pre, ctx), sink))
return sink, si_ctx
@@ -389,7 +388,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
for store_uops in store_groups:
ast, ast_ctx = full_ast_rewrite(UOp.sink(*(realizes[u] for u in store_uops)), ctx)
prescheduled.append(ScheduleItem(ast, tuple(b for u in ast_ctx.bufs if (b:=buffers[u]).size != 0),
tuple(ast_ctx.metadata), tuple(ast_ctx.assign_preloads)))
tuple(ast_ctx.metadata), frozenset(x.buf_uop for x in ast_ctx.assign_preloads)))
# do BFS
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)