mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
faster assign view check [pr] (#7781)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import sys, atexit, functools, itertools
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
from typing import FrozenSet, Set, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint
|
||||
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
|
||||
@@ -23,7 +23,7 @@ class ScheduleItem:
|
||||
ast: UOp
|
||||
bufs: Tuple[Buffer, ...]
|
||||
metadata: Tuple[Metadata, ...]
|
||||
assign_preloads: Tuple[UOp, ...]
|
||||
assign_preloads: FrozenSet[UOp]
|
||||
@property
|
||||
def outputs(self) -> Tuple[Buffer, ...]:
|
||||
"""Read/write or write only buffers in the schedule."""
|
||||
@@ -186,7 +186,7 @@ def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
|
||||
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
|
||||
|
||||
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
|
||||
if b in ctx.assigned: ctx.assign_preloads.append(b)
|
||||
if b in ctx.assigned: ctx.assign_preloads.append(x)
|
||||
return x.replace(op=Ops.LOAD)
|
||||
|
||||
to_si = PatternMatcher([
|
||||
@@ -220,14 +220,13 @@ def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemCon
|
||||
raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.")
|
||||
# do movementops
|
||||
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
|
||||
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
|
||||
if len(assign_targets:=[x.buf_uop for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
|
||||
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
|
||||
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.buf_uop in assign_targets):
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
# convert to AST
|
||||
sink = graph_rewrite(graph_rewrite(sink, to_si, si_ctx), append_bufs, si_ctx)
|
||||
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
|
||||
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
|
||||
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in si_ctx.assign_preloads if si_ctx.sinked.get(x.buf_uop) is not None):
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append(((pre, ctx), sink))
|
||||
return sink, si_ctx
|
||||
|
||||
@@ -389,7 +388,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
for store_uops in store_groups:
|
||||
ast, ast_ctx = full_ast_rewrite(UOp.sink(*(realizes[u] for u in store_uops)), ctx)
|
||||
prescheduled.append(ScheduleItem(ast, tuple(b for u in ast_ctx.bufs if (b:=buffers[u]).size != 0),
|
||||
tuple(ast_ctx.metadata), tuple(ast_ctx.assign_preloads)))
|
||||
tuple(ast_ctx.metadata), frozenset(x.buf_uop for x in ast_ctx.assign_preloads)))
|
||||
# do BFS
|
||||
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
|
||||
graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
|
||||
|
||||
Reference in New Issue
Block a user