generate the individual indexes (#9587)

This commit is contained in:
George Hotz
2025-03-26 22:32:06 +08:00
committed by GitHub
parent 5c6cd884e3
commit 8aaa5e1ec5

View File

@@ -13,18 +13,18 @@ from tinygrad.renderer import Renderer
def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
# first, extract all the relevant offsets
# generate the individual indexes
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
symbolic+load_store_indexing, name=f"index_buf_{buf.arg}")
# extract all the relevant offsets
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
midx, mmask = graph_rewrite(UOp.sink(UOp.sink(*[vec.gep(i) for i in range(vec.dtype.count)]),
UOp.sink(*[mask.gep(i) for i in range(vec.dtype.count)]) if mask is not None else UOp(Ops.NOOP)),
symbolic, name=f"index_buf_{buf.arg}").src
for i in range(vec.dtype.count):
idx: Any = midx.src[i]
idx: Any = midx.src[i].src[1]
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
if mask is not None: root_src = (mmask.src[i], root_src)
if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
# the buf.dtype is always a pointer
@@ -34,12 +34,11 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
ret = []
idxs: list[int|None] = [None]*vec.dtype.count
global_offset = 0
for rootsrc, offsets in offsets_rootsrc.items():
for offsets in offsets_rootsrc.values():
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for grp in grouped_offsets:
# get the index offset for this element. using [0] is okay, because they are the same
oidx = vec.gep(offsets[grp[0]][0])
lidx = UOp(Ops.INDEX, buf.dtype, (buf, oidx, rootsrc[0]) if mask is not None else (buf, oidx))
lidx = midx.src[offsets[grp[0]][0]]
if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local))
# set the idxs of the output
for i,g in enumerate(grp):