mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
noop changes from the block_assign branch [pr] (#8606)
This commit is contained in:
@@ -197,10 +197,13 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
|
||||
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
|
||||
ctx.bufs.append(x)
|
||||
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
|
||||
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
|
||||
|
||||
to_si = PatternMatcher([
|
||||
# BUFFER -> DEFINE_GLOBAL
|
||||
(UPat(Ops.BUFFER, name="x"), _append_buf),
|
||||
# simplify and unbind the final VIEWs
|
||||
(UPat(Ops.VIEW, name="x"), _append_st_vars),
|
||||
# don't need SINK on COPY or BUFFER_VIEW
|
||||
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
|
||||
# don't need contiguous or assign anymore
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
|
||||
@@ -209,9 +212,6 @@ to_si = PatternMatcher([
|
||||
(UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
|
||||
])
|
||||
|
||||
add_assign_adjacents = PatternMatcher([(UPat.load(UPat.var("b"), UPat(), name="x"), lambda ctx,b,x: ctx.assign_adj.setdefault(b, []).append(x)
|
||||
if b in ctx.assigns else None)])
|
||||
|
||||
# LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel
|
||||
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])
|
||||
|
||||
@@ -219,7 +219,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
|
||||
# remove movement ops + substitute LOAD of fused STORE with just the value
|
||||
sink = graph_rewrite(graph_rewrite(pre, multioutput+view_left, store_bufs:={x.buf_uop:x.src[2] for x in pre.src}), view_right)
|
||||
# remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL
|
||||
ast = graph_rewrite(sink, to_si+append_bufs, si_ctx:=ScheduleItemContext(ctx.var_vals))
|
||||
ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(ctx.var_vals))
|
||||
# deal with ASSIGN
|
||||
assign_preloads: list[UOp] = []
|
||||
if len(ctx.assigns) != 0:
|
||||
@@ -540,12 +540,13 @@ def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) ->
|
||||
# preschedule realize groups
|
||||
prescheduled: list[ScheduleItem] = []
|
||||
for store_uops in store_groups:
|
||||
if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) != 0:
|
||||
prescheduled.append(schedule_uop(UOp.sink(*stores), ctx))
|
||||
# can only schedule once
|
||||
for buf_uop in store_uops:
|
||||
for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))
|
||||
# do BFS
|
||||
if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) == 0: continue
|
||||
prescheduled.append(schedule_uop(UOp.sink(*stores), ctx))
|
||||
# can only schedule once
|
||||
for buf_uop in store_uops:
|
||||
for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))
|
||||
|
||||
# add kernel children
|
||||
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
|
||||
graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
|
||||
in_degree: defaultdict[ScheduleItem, int] = defaultdict(int)
|
||||
@@ -560,6 +561,8 @@ def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) ->
|
||||
for x in scheduled_parents:
|
||||
graph[x].append(si)
|
||||
in_degree[si] += 1
|
||||
|
||||
# do BFS
|
||||
queue = deque(si for si in prescheduled if in_degree[si] == 0)
|
||||
schedule: list[ScheduleItem] = []
|
||||
while queue:
|
||||
|
||||
Reference in New Issue
Block a user