diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 072094ae95..e8bc757f5e 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -251,32 +251,27 @@ def no_vectorized_alu(alu:UOp): def no_vectorized_buf(buf:UOp): return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(buf.ptrdtype.size*buf.ptrdtype.count, buf.ptrdtype.addrspace)).cast(buf.dtype) -def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp): +def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None): cnt = cast.dtype.count - assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}" - return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))), ptr=True) - -def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp): - cnt = cast.dtype.count - vcnt = cast.dtype.vcount - precnt = bcast.dtype.vcount - # TODO: I have no idea *why* this is. I just change things until the tests pass. No AI, old school. - if bcast.op is Ops.GEP: - gep_arg = tuple(flatten([range(precnt) for _ in range(vcnt)])) - sum_arg = tuple(flatten([[i+y for y in bcast.arg] for i in range(vcnt)])) + if bcast is not None and bcast.op is Ops.GEP: + # GEP selects specific lanes; bcast.arg[k] is the offset for lane k, iterate groups × selected lanes + pairs = [(k, g + bcast.arg[k]) for g, k in itertools.product(range(cast.dtype.vcount), range(len(bcast.arg)))] + elif bcast is not None: + # BROADCAST: cross product of components × lanes + pairs = [(j, c) for c, j in itertools.product(range(cnt), range(bcast.dtype.vcount))] else: - gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)])) - sum_arg = tuple(flatten([[i]*precnt for i in range(cnt)])) - new_idx = idx.gep(gep_arg)*cnt + UOp.const(dtypes.index.vec(len(sum_arg)), sum_arg) - return buf.broadcast(cnt*precnt).index(new_idx, ptr=True) + # simple scalar index: one lane, all components + pairs = [(0, c) for c in range(cnt)] + idx_lanes, offsets = (tuple(x) for x in zip(*pairs)) + return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.index.vec(len(pairs)), offsets), ptr=True) devectorize_buf_and_index = PatternMatcher([ (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf), (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index), (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")), - no_vectorized_index_broadcast), + no_vectorized_index), (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")), - no_vectorized_index_broadcast), + no_vectorized_index), ]) devectorize = PatternMatcher([