mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
start with a fresh ScheduleItemContext in process_replay [pr] (#7236)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import sys, atexit, functools
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
from tinygrad.ops import BUFFER_UOPS, UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \
|
||||
graph_rewrite, track_rewrites, sint
|
||||
@@ -132,8 +132,8 @@ view_right = merge_views+PatternMatcher([
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleItemContext:
|
||||
var_vals: Dict[Variable, int]
|
||||
sts: Set[ShapeTracker]
|
||||
bufs: List[UOp]
|
||||
sts: Set[ShapeTracker] = field(default_factory=set)
|
||||
bufs: List[UOp] = field(default_factory=list)
|
||||
|
||||
def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
|
||||
if (st:=unwrap(x.st)) in ctx.sts: return None
|
||||
@@ -158,7 +158,7 @@ PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, ScheduleItemContext, UOp]] = []
|
||||
def full_ast_rewrite(base_sink:UOp, ctx:ScheduleItemContext) -> UOp:
|
||||
sink = graph_rewrite(graph_rewrite(base_sink, view_left), view_right)
|
||||
ret = graph_rewrite(sink, to_si, ctx)
|
||||
PROCESS_REPLAY_CAPTURE.append((base_sink, ctx, ret))
|
||||
PROCESS_REPLAY_CAPTURE.append((base_sink, ScheduleItemContext(ctx.var_vals), ret))
|
||||
return ret
|
||||
|
||||
if getenv("RUN_PROCESS_REPLAY"):
|
||||
@@ -198,7 +198,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_val
|
||||
metadata: Dict[UOp, Metadata] = {}
|
||||
sink = UOp(UOps.SINK, src=tuple(UOp.store(buf_uops[out.buffer], ShapeTracker.from_shape(out.shape).to_uop(),
|
||||
to_uop(out, outs, inputs, buf_uops, metadata, cache)) for out in outs))
|
||||
sink = full_ast_rewrite(sink, ctx:=ScheduleItemContext(var_vals, set(), []))
|
||||
sink = full_ast_rewrite(sink, ctx:=ScheduleItemContext(var_vals))
|
||||
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
|
||||
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
|
||||
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
|
||||
|
||||
Reference in New Issue
Block a user