mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
schedule linearize small cleanups [pr] (#9994)
This commit is contained in:
@@ -52,10 +52,11 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
|
||||
var_vals: dict[Variable, int] = {}
|
||||
while queue:
|
||||
k = queue.popleft()
|
||||
# map the BUFFER UOp to a subbuffer if it's a BUFFER_VIEW
|
||||
if k.arg.ast.op is Ops.BUFFER_VIEW:
|
||||
buffers[k.src[0]] = (base:=k.src[1].buf_uop.buffer).view(k.size, k.arg.ast.dtype, k.arg.ast.arg[1]*base.dtype.itemsize)
|
||||
schedule.append(ScheduleItem(graph_rewrite(k.arg.ast, pm_unbind, ctx=var_vals), tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
|
||||
# unbind var_vals from the kernel
|
||||
ast = graph_rewrite(k.arg.ast, pm_unbind, ctx=var_vals)
|
||||
# create subbuffers if needed
|
||||
if ast.op is Ops.BUFFER_VIEW: buffers[k.src[0]] = (base:=k.src[1].buf_uop.buffer).view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
||||
schedule.append(ScheduleItem(ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
|
||||
for x in children[k]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
Reference in New Issue
Block a user