This commit is contained in:
George Hotz
2025-03-17 16:06:37 +08:00
committed by GitHub
parent 52ae9af4dd
commit 242daa4f9a
2 changed files with 6 additions and 6 deletions

View File

@@ -45,7 +45,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
global_offset += len(grp)
assert None not in idxs, f"some idxs are missing {idxs}"
# this base thing is for image, we want the CAT to be a normal pointer
post_cat = UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)) if len(ret) > 1 else ret[0]
post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret))
return post_cat.gep(tuple(cast(list[int], idxs)))
def cat_after_store(cat:UOp, data:UOp):
@@ -74,11 +74,11 @@ load_store_folding = PatternMatcher([
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
# GEP on data of STORE
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st"))), gep_on_store),
# put CAT after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.CAT, name="cat"),), name="ld", allow_any_len=True),
# put PTRCAT after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
# put CAT after STORE
(UPat(Ops.STORE, src=(UPat(Ops.CAT, name="cat"), UPat(name="data"))), cat_after_store),
# put PTRCAT after STORE
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data"))), cat_after_store),
])
# ***** image load valid simplification *****

View File

@@ -117,7 +117,7 @@ class Ops(FastEnum):
REDUCE_AXIS = auto()
# helper ops
GEP = auto(); VECTORIZE = auto(); CAT = auto() # noqa: E702
GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
# UnaryOps
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702