refactor to ScheduleItemContext [pr] (#7217)

This commit is contained in:
qazal
2024-10-22 17:58:06 +03:00
committed by GitHub
parent 7ce12a4b06
commit 24ed2ed6c8

View File

@@ -124,16 +124,24 @@ view_right = merge_views+PatternMatcher([
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
def _append_st_vars(ctx:Tuple[Dict[Variable, int], Set[ShapeTracker], Tuple[int, ...]], x:UOp) -> Optional[UOp]:
if (st:=unwrap(x.st)) in ctx[1]: return None
# ** ScheduleItem context builder
@dataclass(frozen=True)
class ScheduleItemContext:
var_vals: Dict[Variable, int]
sts: Set[ShapeTracker]
bufs: Tuple[int, ...]
def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
if (st:=unwrap(x.st)) in ctx.sts: return None
st, var_vals = st.simplify().unbind()
ctx[0].update(var_vals)
ctx[1].add(st)
ctx.var_vals.update(var_vals)
ctx.sts.add(st)
return st.to_uop() if st != x.st else None
append_st_vars = PatternMatcher([(UPat(UOps.VIEW, name="x"), _append_st_vars)])
def _append_buf(ctx:Tuple[Dict[Variable, int], Set[ShapeTracker], Tuple[int, ...]], x:UOp) -> UOp:
return UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx[2].index(x.arg[0]))
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
return UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx.bufs.index(x.arg[0]))
append_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), _append_buf),])
to_ast = PatternMatcher([
@@ -144,7 +152,7 @@ to_ast = PatternMatcher([
PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, Tuple[int, ...], UOp]] = []
def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...], var_vals:Dict[Variable, int]) -> UOp:
sink = graph_rewrite(graph_rewrite(base_sink, view_left), view_right)
ret = graph_rewrite(graph_rewrite(sink, to_ast), append_st_vars+append_bufs, (var_vals, set(), bufs))
ret = graph_rewrite(graph_rewrite(sink, to_ast), append_st_vars+append_bufs, ScheduleItemContext(var_vals, set(), bufs))
PROCESS_REPLAY_CAPTURE.append((base_sink, bufs, ret))
return ret