From 8aaa5e1ec532240ae0458f0b61a8d85074d46deb Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 26 Mar 2025 22:32:06 +0800 Subject: [PATCH] generate the individual indexes (#9587) --- tinygrad/codegen/devectorizer.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py index c7b95fd9fd..e2b718288e 100644 --- a/tinygrad/codegen/devectorizer.py +++ b/tinygrad/codegen/devectorizer.py @@ -13,18 +13,18 @@ from tinygrad.renderer import Renderer def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None): if getenv("UNSAFE_DISABLE_MASK", 0): mask = None - # first, extract all the relevant offsets + # generate the individual indexes + midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]), + symbolic+load_store_indexing, name=f"index_buf_{buf.arg}") + # extract all the relevant offsets offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict) - midx, mmask = graph_rewrite(UOp.sink(UOp.sink(*[vec.gep(i) for i in range(vec.dtype.count)]), - UOp.sink(*[mask.gep(i) for i in range(vec.dtype.count)]) if mask is not None else UOp(Ops.NOOP)), - symbolic, name=f"index_buf_{buf.arg}").src for i in range(vec.dtype.count): - idx: Any = midx.src[i] + idx: Any = midx.src[i].src[1] if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg else: root_src, arg = idx, 0 - if mask is not None: root_src = (mmask.src[i], root_src) + if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src) offsets_rootsrc[root_src].setdefault(arg, []).append(i) # the buf.dtype is always a pointer @@ -34,12 +34,11 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None): ret = [] idxs: list[int|None] = [None]*vec.dtype.count global_offset = 0 - for rootsrc, offsets in offsets_rootsrc.items(): + for offsets in offsets_rootsrc.values(): grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])] for grp in grouped_offsets: # get the index offset for this element. using [0] is okay, because they are the same - oidx = vec.gep(offsets[grp[0]][0]) - lidx = UOp(Ops.INDEX, buf.dtype, (buf, oidx, rootsrc[0]) if mask is not None else (buf, oidx)) + lidx = midx.src[offsets[grp[0]][0]] if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local)) # set the idxs of the output for i,g in enumerate(grp):