update no_vectorized_index [pr] (#15313)

combine no_vectorized_index and no_vectorized_index_broadcast
This commit is contained in:
chenyu
2026-03-17 03:05:23 -04:00
committed by GitHub
parent 856a839efc
commit 6b6d1814ca

View File

@@ -251,32 +251,27 @@ def no_vectorized_alu(alu:UOp):
def no_vectorized_buf(buf:UOp):
return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(buf.ptrdtype.size*buf.ptrdtype.count, buf.ptrdtype.addrspace)).cast(buf.dtype)
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
cnt = cast.dtype.count
assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))), ptr=True)
def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
cnt = cast.dtype.count
vcnt = cast.dtype.vcount
precnt = bcast.dtype.vcount
# TODO: I have no idea *why* this is. I just change things until the tests pass. No AI, old school.
if bcast.op is Ops.GEP:
gep_arg = tuple(flatten([range(precnt) for _ in range(vcnt)]))
sum_arg = tuple(flatten([[i+y for y in bcast.arg] for i in range(vcnt)]))
if bcast is not None and bcast.op is Ops.GEP:
# GEP selects specific lanes; bcast.arg[k] is the offset for lane k, iterate groups × selected lanes
pairs = [(k, g + bcast.arg[k]) for g, k in itertools.product(range(cast.dtype.vcount), range(len(bcast.arg)))]
elif bcast is not None:
# BROADCAST: cross product of components × lanes
pairs = [(j, c) for c, j in itertools.product(range(cnt), range(bcast.dtype.vcount))]
else:
gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)]))
sum_arg = tuple(flatten([[i]*precnt for i in range(cnt)]))
new_idx = idx.gep(gep_arg)*cnt + UOp.const(dtypes.index.vec(len(sum_arg)), sum_arg)
return buf.broadcast(cnt*precnt).index(new_idx, ptr=True)
# simple scalar index: one lane, all components
pairs = [(0, c) for c in range(cnt)]
idx_lanes, offsets = (tuple(x) for x in zip(*pairs))
return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.index.vec(len(pairs)), offsets), ptr=True)
devectorize_buf_and_index = PatternMatcher([
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
no_vectorized_index_broadcast),
no_vectorized_index),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")),
no_vectorized_index_broadcast),
no_vectorized_index),
])
devectorize = PatternMatcher([