mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
fix beam none if buf is optimized out (#10388)
This commit is contained in:
@@ -84,7 +84,7 @@ def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tup
|
||||
# workers should ignore ctrl c
|
||||
def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
|
||||
def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() for buf in bufs]
|
||||
def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() if buf is not None else buf for buf in bufs]
|
||||
|
||||
# *** external API ***
|
||||
|
||||
@@ -93,14 +93,16 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
|
||||
bufsts: defaultdict[int, list[UOp]] = defaultdict(list)
|
||||
for x in lin.bufs:
|
||||
if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
|
||||
rawbufs: list[Optional[Buffer]] = [None]*len(bufsts)
|
||||
# TODO: Nones are staying in here if buffers are optimized out!
|
||||
# TODO: add a test for this
|
||||
rawbufs: list[Optional[Buffer]] = [None]*(max(bufsts)+1)
|
||||
for k,lx in bufsts.items():
|
||||
buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)
|
||||
assert isinstance(dtype, (PtrDType, ImageDType))
|
||||
if buf_size == 0: buf_size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
|
||||
buf_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
|
||||
rawbufs[k] = Buffer(lin.opts.device, buf_size, buf_dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, buf_dtype)
|
||||
assert all(r is not None for r in rawbufs)
|
||||
#assert all(r is not None for r in rawbufs)
|
||||
return cast(list[Buffer], rawbufs)
|
||||
|
||||
# get dictionary of all possible actions
|
||||
|
||||
Reference in New Issue
Block a user