start with a fresh ScheduleItemContext in process_replay [pr] (#7236)

This commit is contained in:
qazal
2024-10-23 18:01:50 +03:00
committed by GitHub
parent ca6c58527b
commit ca7b2658b9

View File

@@ -1,6 +1,6 @@
import sys, atexit, functools
from collections import defaultdict, deque
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import BUFFER_UOPS, UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \
graph_rewrite, track_rewrites, sint
@@ -132,8 +132,8 @@ view_right = merge_views+PatternMatcher([
@dataclass(frozen=True)
class ScheduleItemContext:
var_vals: Dict[Variable, int]
sts: Set[ShapeTracker]
bufs: List[UOp]
sts: Set[ShapeTracker] = field(default_factory=set)
bufs: List[UOp] = field(default_factory=list)
def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
if (st:=unwrap(x.st)) in ctx.sts: return None
@@ -158,7 +158,7 @@ PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, ScheduleItemContext, UOp]] = []
def full_ast_rewrite(base_sink:UOp, ctx:ScheduleItemContext) -> UOp:
sink = graph_rewrite(graph_rewrite(base_sink, view_left), view_right)
ret = graph_rewrite(sink, to_si, ctx)
PROCESS_REPLAY_CAPTURE.append((base_sink, ctx, ret))
PROCESS_REPLAY_CAPTURE.append((base_sink, ScheduleItemContext(ctx.var_vals), ret))
return ret
if getenv("RUN_PROCESS_REPLAY"):
@@ -198,7 +198,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_val
metadata: Dict[UOp, Metadata] = {}
sink = UOp(UOps.SINK, src=tuple(UOp.store(buf_uops[out.buffer], ShapeTracker.from_shape(out.shape).to_uop(),
to_uop(out, outs, inputs, buf_uops, metadata, cache)) for out in outs))
sink = full_ast_rewrite(sink, ctx:=ScheduleItemContext(var_vals, set(), []))
sink = full_ast_rewrite(sink, ctx:=ScheduleItemContext(var_vals))
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \