noop changes from the block_assign branch [pr] (#8606)

This commit is contained in:
qazal
2025-01-14 07:47:17 -05:00
committed by GitHub
parent 5aab2806f0
commit 97ec564b03

View File

@@ -197,10 +197,13 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
  """Register BUFFER uop `x` in ctx.bufs and return a DEFINE_GLOBAL whose arg is its slot index."""
  ctx.bufs.append(x)
  # arg = position of this buffer in ctx.bufs; dtype becomes a pointer sized to the buffer
  return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
to_si = PatternMatcher([
# BUFFER -> DEFINE_GLOBAL
(UPat(Ops.BUFFER, name="x"), _append_buf),
# simplify and unbind the final VIEWs
(UPat(Ops.VIEW, name="x"), _append_st_vars),
# don't need SINK on COPY or BUFFER_VIEW
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
# don't need contiguous or assign anymore
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
@@ -209,9 +212,6 @@ to_si = PatternMatcher([
(UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
])
# collect every LOAD of a buffer that is also an ASSIGN target into ctx.assign_adj, grouped by buffer uop
add_assign_adjacents = PatternMatcher([(UPat.load(UPat.var("b"), UPat(), name="x"), lambda ctx,b,x: ctx.assign_adj.setdefault(b, []).append(x)
if b in ctx.assigns else None)])
# LOAD(BUFFER) -> the STORE value if we're doing the STORE in the same kernel
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])
@@ -219,7 +219,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
# remove movement ops + substitute LOAD of fused STORE with just the value
sink = graph_rewrite(graph_rewrite(pre, multioutput+view_left, store_bufs:={x.buf_uop:x.src[2] for x in pre.src}), view_right)
# remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
ast = graph_rewrite(sink, to_si+append_bufs, si_ctx:=ScheduleItemContext(ctx.var_vals))
ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(ctx.var_vals))
# deal with ASSIGN
assign_preloads: list[UOp] = []
if len(ctx.assigns) != 0:
@@ -540,12 +540,13 @@ def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) ->
# preschedule realize groups
prescheduled: list[ScheduleItem] = []
for store_uops in store_groups:
if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) != 0:
prescheduled.append(schedule_uop(UOp.sink(*stores), ctx))
# can only schedule once
for buf_uop in store_uops:
for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))
# do BFS
if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) == 0: continue
prescheduled.append(schedule_uop(UOp.sink(*stores), ctx))
# can only schedule once
for buf_uop in store_uops:
for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))
# add kernel children
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
in_degree: defaultdict[ScheduleItem, int] = defaultdict(int)
@@ -560,6 +561,8 @@ def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) ->
for x in scheduled_parents:
graph[x].append(si)
in_degree[si] += 1
# do BFS
queue = deque(si for si in prescheduled if in_degree[si] == 0)
schedule: list[ScheduleItem] = []
while queue: