mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 15:28:10 -05:00
refactor to kernel ast fixup [pr] (#9376)
This commit is contained in:
@@ -372,12 +372,12 @@ class ScheduleItem:
|
||||
bufs: tuple[Buffer, ...]
|
||||
metadata: tuple[Metadata, ...]
|
||||
|
||||
def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
|
||||
assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"
|
||||
def schedule_uop(k:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
|
||||
assert k.op is Ops.KERNEL, f"kernel isn't kernel, it's {k}"
|
||||
# substitute kernel sources for the target buffer
|
||||
ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
|
||||
ast = k.arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in k.src if s.op is Ops.ASSIGN}).sink()
|
||||
# add buffer ops
|
||||
ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
|
||||
ast = graph_rewrite(ast, add_buffer_ops, bufs:=[s.buf_uop for s in k.src], bottom_up=True)
|
||||
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
|
||||
# unbind_vars + push views to edges
|
||||
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
|
||||
@@ -385,7 +385,7 @@ def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
|
||||
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
|
||||
# create subbuffer
|
||||
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
||||
return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata)
|
||||
return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), k.arg.metadata)
|
||||
|
||||
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
@@ -459,7 +459,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
var_vals: dict[Variable, int] = {}
|
||||
while queue:
|
||||
u = queue.popleft()
|
||||
schedule.append(schedule_uop(u, var_vals))
|
||||
schedule.append(schedule_uop(u.src[1], var_vals))
|
||||
# increment the refcount of the target buf (this is required by the JIT and memory planner)
|
||||
u.buf_uop.buffer.ref(1)
|
||||
for x in children.get(u, []):
|
||||
|
||||
@@ -513,6 +513,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
|
||||
@property
|
||||
def buf_uop(self) -> UOp:
|
||||
if self.op is Ops.BUFFER: return self
|
||||
assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
|
||||
return self.src[0].base
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user