refactor to kernel ast fixup [pr] (#9376)

This commit is contained in:
qazal
2025-03-07 16:47:38 +02:00
committed by GitHub
parent 304afe0d55
commit 3565c08df5
2 changed files with 7 additions and 6 deletions

View File

@@ -372,12 +372,12 @@ class ScheduleItem:
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...]
def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"
def schedule_uop(k:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
assert k.op is Ops.KERNEL, f"kernel isn't kernel, it's {k}"
# substitute kernel sources for the target buffer
ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
ast = k.arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in k.src if s.op is Ops.ASSIGN}).sink()
# add buffer ops
ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
ast = graph_rewrite(ast, add_buffer_ops, bufs:=[s.buf_uop for s in k.src], bottom_up=True)
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
# unbind_vars + push views to edges
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
@@ -385,7 +385,7 @@ def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
# create subbuffer
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata)
return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), k.arg.metadata)
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
@@ -459,7 +459,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
var_vals: dict[Variable, int] = {}
while queue:
u = queue.popleft()
schedule.append(schedule_uop(u, var_vals))
schedule.append(schedule_uop(u.src[1], var_vals))
# increment the refcount of the target buf (this is required by the JIT and memory planner)
u.buf_uop.buffer.ref(1)
for x in children.get(u, []):

View File

@@ -513,6 +513,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
@property
def buf_uop(self) -> UOp:
if self.op is Ops.BUFFER: return self
assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
return self.src[0].base
@property