diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 36312649ae..7734fb2d93 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -372,12 +372,12 @@ class ScheduleItem: bufs: tuple[Buffer, ...] metadata: tuple[Metadata, ...] -def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem: - assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN" +def schedule_uop(k:UOp, var_vals:dict[Variable, int]) -> ScheduleItem: + assert k.op is Ops.KERNEL, f"kernel isn't kernel, it's {k}" # substitute kernel sources for the target buffer - ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink() + ast = k.arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in k.src if s.op is Ops.ASSIGN}).sink() # add buffer ops - ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True) + ast = graph_rewrite(ast, add_buffer_ops, bufs:=[s.buf_uop for s in k.src], bottom_up=True) if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}") # unbind_vars + push views to edges ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right) @@ -385,7 +385,7 @@ def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem: ast = graph_rewrite(ast, fix_kernel_ops, var_vals) # create subbuffer if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize) - return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata) + return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), k.arg.metadata) PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {} if CAPTURE_PROCESS_REPLAY: @@ -459,7 +459,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va var_vals: dict[Variable, int] = {} while queue: u = queue.popleft() - schedule.append(schedule_uop(u, var_vals)) + schedule.append(schedule_uop(u.src[1], var_vals)) # increment the refcount of the target buf (this is required by the JIT and memory planner) u.buf_uop.buffer.ref(1) for x in children.get(u, []): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 7ae4e645fb..f2576f2598 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -513,6 +513,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None @property def buf_uop(self) -> UOp: + if self.op is Ops.BUFFER: return self assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}" return self.src[0].base @property