mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
ptrcat (#9473)
This commit is contained in:
@@ -45,7 +45,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
||||
global_offset += len(grp)
|
||||
assert None not in idxs, f"some idxs are missing {idxs}"
|
||||
# this base thing is for image, we want the CAT to be a normal pointer
|
||||
post_cat = UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)) if len(ret) > 1 else ret[0]
|
||||
post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret))
|
||||
return post_cat.gep(tuple(cast(list[int], idxs)))
|
||||
|
||||
def cat_after_store(cat:UOp, data:UOp):
|
||||
@@ -74,11 +74,11 @@ load_store_folding = PatternMatcher([
|
||||
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
|
||||
# GEP on data of STORE
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st"))), gep_on_store),
|
||||
# put CAT after LOAD
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.CAT, name="cat"),), name="ld", allow_any_len=True),
|
||||
# put PTRCAT after LOAD
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
|
||||
lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
|
||||
# put CAT after STORE
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.CAT, name="cat"), UPat(name="data"))), cat_after_store),
|
||||
# put PTRCAT after STORE
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data"))), cat_after_store),
|
||||
])
|
||||
|
||||
# ***** image load valid simplification *****
|
||||
|
||||
@@ -117,7 +117,7 @@ class Ops(FastEnum):
|
||||
REDUCE_AXIS = auto()
|
||||
|
||||
# helper ops
|
||||
GEP = auto(); VECTORIZE = auto(); CAT = auto() # noqa: E702
|
||||
GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
|
||||
|
||||
# UnaryOps
|
||||
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
|
||||
|
||||
Reference in New Issue
Block a user