fix beam none if buf is optimized out (#10388)

This commit is contained in:
George Hotz
2025-05-17 21:50:33 -07:00
committed by GitHub
parent 6f77b938d7
commit 305a3231c4

View File

@@ -84,7 +84,7 @@ def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tup
# workers should ignore ctrl c
def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() for buf in bufs]
def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() if buf is not None else buf for buf in bufs]
# *** external API ***
@@ -93,14 +93,16 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
bufsts: defaultdict[int, list[UOp]] = defaultdict(list)
for x in lin.bufs:
if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
rawbufs: list[Optional[Buffer]] = [None]*len(bufsts)
# TODO: Nones are staying in here if buffers are optimized out!
# TODO: add a test for this
rawbufs: list[Optional[Buffer]] = [None]*(max(bufsts)+1)
for k,lx in bufsts.items():
buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)
assert isinstance(dtype, (PtrDType, ImageDType))
if buf_size == 0: buf_size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
buf_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
rawbufs[k] = Buffer(lin.opts.device, buf_size, buf_dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, buf_dtype)
assert all(r is not None for r in rawbufs)
#assert all(r is not None for r in rawbufs)
return cast(list[Buffer], rawbufs)
# get dictionary of all possible actions