becomes_map cleanups [pr] (#9790)

* cleanup becomes_map [pr]

* source
This commit is contained in:
qazal
2025-04-09 14:11:53 +08:00
committed by GitHub
parent 7d2349c827
commit f27dbc8c35

View File

@@ -408,16 +408,17 @@ def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
# map tensors to buffer/const, optionally apply a VIEW on top
becomes_map: dict[UOp, UOp] = {}
for k,v in tensor_map.items():
# ASSIGN always becomes the target buffer
if v.op is Ops.ASSIGN: becomes_map[k] = v.src[0]
# if we created a new buffer for this tensor, map it to the assigned buffer
elif (a:=kernel_map.get(v.base)) is not None and (a:=a.base).op is Ops.ASSIGN:
becomes_map[k] = a.src[0] if a.src[0].st == v.st else a.src[0].view(unwrap(v.st))
# tensors can also simplify to an existing buffer/const
else:
if k is v: continue
if v.base.op is Ops.BUFFER: becomes_map[k] = v
if v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
if (kernel:=kernel_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st))
if k is v: continue
if k.op is Ops.ASSIGN:
becomes_map[k] = k.src[0]
continue
op = v.base.op
if op is Ops.BUFFER: becomes_map[k] = v
if op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
if op is Ops.ASSIGN:
new_buf = v.base.src[0]
becomes_map[k] = new_buf if new_buf.st == v.st else new_buf.view(unwrap(v.st))
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}