mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
rename CAT to VCAT (#15167)
This commit is contained in:
@@ -756,7 +756,7 @@ class TestLoadStoreFolding(unittest.TestCase):
|
||||
self.assertEqual(len(gated_load.src), 2) # PTRCAT + alt
|
||||
result = graph_rewrite(gated_load, load_store_folding, name='test')
|
||||
# After rewrite, should be VCAT of LOADs, each preserving alt
|
||||
self.assertEqual(result.op, Ops.CAT)
|
||||
self.assertEqual(result.op, Ops.VCAT)
|
||||
for inner_load in result.src:
|
||||
self.assertEqual(inner_load.op, Ops.LOAD)
|
||||
self.assertEqual(len(inner_load.src), 2) # INDEX + alt
|
||||
|
||||
@@ -130,7 +130,7 @@ load_store_folding = PatternMatcher([
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), name="sto"), gep_on_store),
|
||||
# put PTRCAT after LOAD
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
|
||||
lambda cat,ld: UOp(Ops.CAT, cat.dtype.base.vec(cat.dtype.vcount), tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
|
||||
lambda cat,ld: UOp(Ops.VCAT, cat.dtype.base.vec(cat.dtype.vcount), tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
|
||||
# put PTRCAT after STORE
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), name="sto"), cat_after_store),
|
||||
])
|
||||
@@ -181,7 +181,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
||||
|
||||
# if it wasn't split, we return None. otherwise we VCAT them
|
||||
if len(ret) <= 1: return None
|
||||
return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
|
||||
return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
|
||||
|
||||
def _do_image_fixup(dt:ImageDType, idx:UOp) -> tuple[UOp, UOp, int, int]:
|
||||
buf = idx.src[0]
|
||||
|
||||
@@ -51,7 +51,7 @@ def do_expand(root:UOp):
|
||||
new_srcs.append(src)
|
||||
elif src.dtype.count > 1:
|
||||
# put any input dtype > 1 grouped together
|
||||
new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
|
||||
new_srcs.append(UOp(Ops.VCAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
|
||||
else:
|
||||
# repeat the arg
|
||||
new_srcs.append(src.broadcast(expand_sz))
|
||||
|
||||
@@ -101,7 +101,7 @@ class Ops(FastEnum):
|
||||
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
|
||||
|
||||
# expander ops
|
||||
UNROLL = auto(); CONTRACT = auto(); CAT = auto(); PTRCAT = auto()
|
||||
UNROLL = auto(); CONTRACT = auto(); VCAT = auto(); PTRCAT = auto()
|
||||
|
||||
class GroupOp:
|
||||
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIPROCAL, Ops.NEG, Ops.TRUNC}
|
||||
|
||||
@@ -260,7 +260,7 @@ full_spec = PatternMatcher([
|
||||
# PTRCAT is like VECTORIZE, but it functions on ptrs
|
||||
(UPat(Ops.PTRCAT, name="x"), lambda x: x.dtype.vcount == sum([y.dtype.base.count for y in x.src])),
|
||||
# VCAT is like VECTORIZE, but the srcs can be vectors
|
||||
(UPat(Ops.CAT, name="x"), lambda x: x.dtype.vcount == sum([y.dtype.vcount for y in x.src])),
|
||||
(UPat(Ops.VCAT, name="x"), lambda x: x.dtype.vcount == sum([y.dtype.vcount for y in x.src])),
|
||||
# vectorized index
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),
|
||||
|
||||
|
||||
@@ -175,7 +175,7 @@ gep_pushing = PatternMatcher([
|
||||
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
|
||||
if not isinstance(gep.dtype, PtrDType) and not isinstance(alu.dtype, PtrDType) else None),
|
||||
# VCAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZE with GEPs (TODO: move this later)
|
||||
(UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
|
||||
(UPat(Ops.VCAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
|
||||
if not isinstance(x.dtype, PtrDType) else None),
|
||||
# VECTORIZE on same GEP
|
||||
(UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
|
||||
|
||||
Reference in New Issue
Block a user