mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
refactor to ScheduleItemContext [pr] (#7217)
This commit is contained in:
@@ -124,16 +124,24 @@ view_right = merge_views+PatternMatcher([
|
||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
])
|
||||
|
||||
def _append_st_vars(ctx:Tuple[Dict[Variable, int], Set[ShapeTracker], Tuple[int, ...]], x:UOp) -> Optional[UOp]:
|
||||
if (st:=unwrap(x.st)) in ctx[1]: return None
|
||||
# ** ScheduleItem context builder
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleItemContext:
|
||||
var_vals: Dict[Variable, int]
|
||||
sts: Set[ShapeTracker]
|
||||
bufs: Tuple[int, ...]
|
||||
|
||||
def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
|
||||
if (st:=unwrap(x.st)) in ctx.sts: return None
|
||||
st, var_vals = st.simplify().unbind()
|
||||
ctx[0].update(var_vals)
|
||||
ctx[1].add(st)
|
||||
ctx.var_vals.update(var_vals)
|
||||
ctx.sts.add(st)
|
||||
return st.to_uop() if st != x.st else None
|
||||
append_st_vars = PatternMatcher([(UPat(UOps.VIEW, name="x"), _append_st_vars)])
|
||||
|
||||
def _append_buf(ctx:Tuple[Dict[Variable, int], Set[ShapeTracker], Tuple[int, ...]], x:UOp) -> UOp:
|
||||
return UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx[2].index(x.arg[0]))
|
||||
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
|
||||
return UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx.bufs.index(x.arg[0]))
|
||||
append_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), _append_buf),])
|
||||
|
||||
to_ast = PatternMatcher([
|
||||
@@ -144,7 +152,7 @@ to_ast = PatternMatcher([
|
||||
PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, Tuple[int, ...], UOp]] = []
|
||||
def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...], var_vals:Dict[Variable, int]) -> UOp:
|
||||
sink = graph_rewrite(graph_rewrite(base_sink, view_left), view_right)
|
||||
ret = graph_rewrite(graph_rewrite(sink, to_ast), append_st_vars+append_bufs, (var_vals, set(), bufs))
|
||||
ret = graph_rewrite(graph_rewrite(sink, to_ast), append_st_vars+append_bufs, ScheduleItemContext(var_vals, set(), bufs))
|
||||
PROCESS_REPLAY_CAPTURE.append((base_sink, bufs, ret))
|
||||
return ret
|
||||
|
||||
|
||||
Reference in New Issue
Block a user