diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index dd63c9c470..f59106a71d 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,7 +1,7 @@ import sys, atexit, functools, itertools from collections import defaultdict, deque from dataclasses import dataclass, field -from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast +from typing import FrozenSet, Set, Tuple, List, Dict, Optional, DefaultDict, cast from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG @@ -23,7 +23,7 @@ class ScheduleItem: ast: UOp bufs: Tuple[Buffer, ...] metadata: Tuple[Metadata, ...] - assign_preloads: Tuple[UOp, ...] + assign_preloads: FrozenSet[UOp] @property def outputs(self) -> Tuple[Buffer, ...]: """Read/write or write only buffers in the schedule.""" @@ -186,7 +186,7 @@ def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp: append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)]) def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp: - if b in ctx.assigned: ctx.assign_preloads.append(b) + if b in ctx.assigned: ctx.assign_preloads.append(x) return x.replace(op=Ops.LOAD) to_si = PatternMatcher([ @@ -220,14 +220,13 @@ def full_ast_rewrite(pre:UOp, ctx:ScheduleContext) -> Tuple[UOp, ScheduleItemCon raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.") # do movementops sink = graph_rewrite(graph_rewrite(sink, view_left), view_right) - # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine - if len(assign_targets:=[x.buf_uop for x in sink.sparents if x.op is Ops.ASSIGN]) != 0: - if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \ - and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.buf_uop in assign_targets): - raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" - +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) # convert to AST sink = graph_rewrite(graph_rewrite(sink, to_si, si_ctx), append_bufs, si_ctx) + # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine + if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \ + and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in si_ctx.assign_preloads if si_ctx.sinked.get(x.buf_uop) is not None): + raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" + +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append(((pre, ctx), sink)) return sink, si_ctx @@ -389,7 +388,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] for store_uops in store_groups: ast, ast_ctx = full_ast_rewrite(UOp.sink(*(realizes[u] for u in store_uops)), ctx) prescheduled.append(ScheduleItem(ast, tuple(b for u in ast_ctx.bufs if (b:=buffers[u]).size != 0), - tuple(ast_ctx.metadata), tuple(ast_ctx.assign_preloads))) + tuple(ast_ctx.metadata), frozenset(x.buf_uop for x in ast_ctx.assign_preloads))) # do BFS schedule_targets = {out:si for si in prescheduled for out in si.outputs} graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)