rename CAT to VCAT (#15167)

This commit is contained in:
George Hotz
2026-03-06 18:46:28 +08:00
committed by GitHub
parent 059c6326c0
commit 6fd18ef875
6 changed files with 7 additions and 7 deletions

View File

@@ -756,7 +756,7 @@ class TestLoadStoreFolding(unittest.TestCase):
self.assertEqual(len(gated_load.src), 2) # PTRCAT + alt
result = graph_rewrite(gated_load, load_store_folding, name='test')
# After rewrite, should be a VCAT of LOADs, each preserving alt
self.assertEqual(result.op, Ops.CAT)
self.assertEqual(result.op, Ops.VCAT)
for inner_load in result.src:
self.assertEqual(inner_load.op, Ops.LOAD)
self.assertEqual(len(inner_load.src), 2) # INDEX + alt

View File

@@ -130,7 +130,7 @@ load_store_folding = PatternMatcher([
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), name="sto"), gep_on_store),
# put PTRCAT after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
lambda cat,ld: UOp(Ops.CAT, cat.dtype.base.vec(cat.dtype.vcount), tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
lambda cat,ld: UOp(Ops.VCAT, cat.dtype.base.vec(cat.dtype.vcount), tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
# put PTRCAT after STORE
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), name="sto"), cat_after_store),
])
@@ -181,7 +181,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
# if it wasn't split, we return None. otherwise we VCAT them
if len(ret) <= 1: return None
return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
def _do_image_fixup(dt:ImageDType, idx:UOp) -> tuple[UOp, UOp, int, int]:
buf = idx.src[0]

View File

@@ -51,7 +51,7 @@ def do_expand(root:UOp):
new_srcs.append(src)
elif src.dtype.count > 1:
# put any input dtype > 1 grouped together
new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
new_srcs.append(UOp(Ops.VCAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
else:
# repeat the arg
new_srcs.append(src.broadcast(expand_sz))

View File

@@ -101,7 +101,7 @@ class Ops(FastEnum):
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
# expander ops
UNROLL = auto(); CONTRACT = auto(); CAT = auto(); PTRCAT = auto()
UNROLL = auto(); CONTRACT = auto(); VCAT = auto(); PTRCAT = auto()
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIPROCAL, Ops.NEG, Ops.TRUNC}

View File

@@ -260,7 +260,7 @@ full_spec = PatternMatcher([
# PTRCAT is like VECTORIZE, but it functions on ptrs
(UPat(Ops.PTRCAT, name="x"), lambda x: x.dtype.vcount == sum([y.dtype.base.count for y in x.src])),
# VCAT is like VECTORIZE, but the srcs can be vectors
(UPat(Ops.CAT, name="x"), lambda x: x.dtype.vcount == sum([y.dtype.vcount for y in x.src])),
(UPat(Ops.VCAT, name="x"), lambda x: x.dtype.vcount == sum([y.dtype.vcount for y in x.src])),
# vectorized index
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),

View File

@@ -175,7 +175,7 @@ gep_pushing = PatternMatcher([
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
if not isinstance(gep.dtype, PtrDType) and not isinstance(alu.dtype, PtrDType) else None),
# VCAT can't be rendered. it's a VECTORIZE on vectors, so we expand it to a single VECTORIZE with GEPs (TODO: move this later)
(UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
(UPat(Ops.VCAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
if not isinstance(x.dtype, PtrDType) else None),
# VECTORIZE on same GEP
(UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),