mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 07:18:40 -05:00
generate the individual indexes (#9587)
This commit is contained in:
@@ -13,18 +13,18 @@ from tinygrad.renderer import Renderer
|
||||
|
||||
def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
||||
if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
|
||||
# first, extract all the relevant offsets
|
||||
# generate the individual indexes
|
||||
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
|
||||
symbolic+load_store_indexing, name=f"index_buf_{buf.arg}")
|
||||
# extract all the relevant offsets
|
||||
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
|
||||
midx, mmask = graph_rewrite(UOp.sink(UOp.sink(*[vec.gep(i) for i in range(vec.dtype.count)]),
|
||||
UOp.sink(*[mask.gep(i) for i in range(vec.dtype.count)]) if mask is not None else UOp(Ops.NOOP)),
|
||||
symbolic, name=f"index_buf_{buf.arg}").src
|
||||
for i in range(vec.dtype.count):
|
||||
idx: Any = midx.src[i]
|
||||
idx: Any = midx.src[i].src[1]
|
||||
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
|
||||
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
|
||||
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
|
||||
else: root_src, arg = idx, 0
|
||||
if mask is not None: root_src = (mmask.src[i], root_src)
|
||||
if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
|
||||
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
|
||||
|
||||
# the buf.dtype is always a pointer
|
||||
@@ -34,12 +34,11 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
||||
ret = []
|
||||
idxs: list[int|None] = [None]*vec.dtype.count
|
||||
global_offset = 0
|
||||
for rootsrc, offsets in offsets_rootsrc.items():
|
||||
for offsets in offsets_rootsrc.values():
|
||||
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
|
||||
for grp in grouped_offsets:
|
||||
# get the index offset for this element. using [0] is okay, because they are the same
|
||||
oidx = vec.gep(offsets[grp[0]][0])
|
||||
lidx = UOp(Ops.INDEX, buf.dtype, (buf, oidx, rootsrc[0]) if mask is not None else (buf, oidx))
|
||||
lidx = midx.src[offsets[grp[0]][0]]
|
||||
if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local))
|
||||
# set the idxs of the output
|
||||
for i,g in enumerate(grp):
|
||||
|
||||
Reference in New Issue
Block a user