init changes from the global buffer branch [pr] (#7939)

This commit is contained in:
qazal
2024-11-28 06:38:58 -05:00
committed by GitHub
parent 81d415be03
commit 3ab67d45b2

View File

@@ -68,14 +68,16 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
assert buf.op is not None, f"base must be base itself {buf}"
dtype = buf.dtype if buf.op in GroupOp.Meta else buf.dtype.base
if buf.is_realized:
buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype, num=len(buffers))
buffers[ubuf] = buf.buffer
op = None
elif buf.op is Ops.ASSIGN:
target, new_val = [to_uop(x, ctx, buffers, cache) for x in buf.srcs]
ctx.assigns.add(ubuf:=target.buf_uop)
op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
else:
buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype, num=len(buffers))
buffers[ubuf] = buf.buffer
op = UOp(buf.op, dtype, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg)
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
if op is not None: