mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update no_vectorized_index [pr] (#15313)
combine no_vectorized_index and no_vectorized_index_broadcast
This commit is contained in:
@@ -251,32 +251,27 @@ def no_vectorized_alu(alu:UOp):
|
||||
def no_vectorized_buf(buf:UOp):
  # rebuild the pointer dtype over the scalar base type; its size is the old size multiplied by the old count
  ptr = buf.ptrdtype
  flat_dtype = ptr.base.scalar().ptr(ptr.size*ptr.count, ptr.addrspace)
  # swap in the new dtype, then cast so the returned node carries buf's original dtype
  return buf.replace(dtype=flat_dtype).cast(buf.dtype)
|
||||
|
||||
# NOTE(review): this span is a rendered diff hunk with the +/- markers stripped, so pre-commit and
# post-commit lines are interleaved. Per the commit description, no_vectorized_index and
# no_vectorized_index_broadcast are combined into a single no_vectorized_index that takes an
# optional bcast. Which physical line belongs to which side is inferred -- confirm against the repo.
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
|
||||
# Signature with bcast defaulting to None; the branches below test bcast for None and for Ops.GEP.
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
|
||||
# cnt multiplies idx and bounds the per-component offsets in the index expressions below
cnt = cast.dtype.count
|
||||
assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
|
||||
# lane c of the result indexes buf at idx*cnt + c, for c in range(cnt)
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))), ptr=True)
|
||||
|
||||
# Variant that additionally receives the node matched as "bcast" (tested for Ops.GEP below).
def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
|
||||
cnt = cast.dtype.count
|
||||
vcnt = cast.dtype.vcount
|
||||
precnt = bcast.dtype.vcount
|
||||
# TODO: I have no idea *why* this is. I just change things until the tests pass. No AI, old school.
|
||||
if bcast.op is Ops.GEP:
|
||||
# gep_arg repeats range(precnt) vcnt times; sum_arg adds each bcast.arg entry to the group number i
gep_arg = tuple(flatten([range(precnt) for _ in range(vcnt)]))
|
||||
sum_arg = tuple(flatten([[i+y for y in bcast.arg] for i in range(vcnt)]))
|
||||
# Each branch below builds `pairs`: a list of (lane of idx to read, offset to add) tuples, one per output lane.
if bcast is not None and bcast.op is Ops.GEP:
|
||||
# GEP selects specific lanes; bcast.arg[k] is the offset for lane k, iterate groups × selected lanes
|
||||
pairs = [(k, g + bcast.arg[k]) for g, k in itertools.product(range(cast.dtype.vcount), range(len(bcast.arg)))]
|
||||
elif bcast is not None:
|
||||
# BROADCAST: cross product of components × lanes
|
||||
pairs = [(j, c) for c, j in itertools.product(range(cnt), range(bcast.dtype.vcount))]
|
||||
else:
|
||||
# gep_arg repeats range(precnt) cnt times; sum_arg repeats each component number i precnt times
gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)]))
|
||||
sum_arg = tuple(flatten([[i]*precnt for i in range(cnt)]))
|
||||
# lane i of new_idx is idx[gep_arg[i]]*cnt + sum_arg[i]
new_idx = idx.gep(gep_arg)*cnt + UOp.const(dtypes.index.vec(len(sum_arg)), sum_arg)
|
||||
return buf.broadcast(cnt*precnt).index(new_idx, ptr=True)
|
||||
# simple scalar index: one lane, all components
|
||||
pairs = [(0, c) for c in range(cnt)]
|
||||
# transpose pairs into two parallel tuples (idx lanes, offsets).
# NOTE(review): an empty `pairs` would make this two-target unpack raise ValueError -- presumably cnt/vcount are always >= 1; confirm.
idx_lanes, offsets = (tuple(x) for x in zip(*pairs))
|
||||
# output lane i indexes buf at idx[idx_lanes[i]]*cnt + offsets[i]; buf is broadcast to one copy per pair
return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.index.vec(len(pairs)), offsets), ptr=True)
|
||||
|
||||
# Rewrite rules for DEFINE_LOCAL / DEFINE_REG buffers and for INDEX nodes built on a cast of such a buffer.
# NOTE(review): this is a rendered diff with +/- markers stripped. The last two rules each show two handler
# lines (no_vectorized_index_broadcast and no_vectorized_index); per the commit description these look like
# the before/after sides of the change -- confirm against the repo.
devectorize_buf_and_index = PatternMatcher([
|
||||
# the buffer definition itself, bound as "buf"
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
|
||||
# buf -> cast -> index: binds "buf", "cast" and "idx" (no "bcast")
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
|
||||
# buf -> cast -> broadcast -> index: the broadcast node is bound as "bcast"
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
|
||||
no_vectorized_index_broadcast),
|
||||
no_vectorized_index),
|
||||
# buf -> cast -> gep -> index: the gep node is bound under the same name, "bcast"
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")),
|
||||
no_vectorized_index_broadcast),
|
||||
no_vectorized_index),
|
||||
])
|
||||
|
||||
devectorize = PatternMatcher([
|
||||
|
||||
Reference in New Issue
Block a user