diff --git a/test/null/test_uop_graph.py b/test/null/test_uop_graph.py index db4bfba1e8..6691091ac1 100644 --- a/test/null/test_uop_graph.py +++ b/test/null/test_uop_graph.py @@ -756,7 +756,7 @@ class TestLoadStoreFolding(unittest.TestCase): self.assertEqual(len(gated_load.src), 2) # PTRCAT + alt result = graph_rewrite(gated_load, load_store_folding, name='test') # After rewrite, should be CAT of LOADs, each preserving alt - self.assertEqual(result.op, Ops.CAT) + self.assertEqual(result.op, Ops.VCAT) for inner_load in result.src: self.assertEqual(inner_load.op, Ops.LOAD) self.assertEqual(len(inner_load.src), 2) # INDEX + alt diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 45c7e73546..f13b4f475f 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -130,7 +130,7 @@ load_store_folding = PatternMatcher([ (UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), name="sto"), gep_on_store), # put PTRCAT after LOAD (UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True), - lambda cat,ld: UOp(Ops.CAT, cat.dtype.base.vec(cat.dtype.vcount), tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))), + lambda cat,ld: UOp(Ops.VCAT, cat.dtype.base.vec(cat.dtype.vcount), tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))), # put PTRCAT after STORE (UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), name="sto"), cat_after_store), ]) @@ -181,7 +181,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp): # if it wasn't split, we return None. otherwise we CAT them if len(ret) <= 1: return None - return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret) + return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret) def _do_image_fixup(dt:ImageDType, idx:UOp) -> tuple[UOp, UOp, int, int]: buf = idx.src[0] diff --git a/tinygrad/codegen/late/expander.py b/tinygrad/codegen/late/expander.py index b12dc147e7..9ccd700723 100644 --- a/tinygrad/codegen/late/expander.py +++ b/tinygrad/codegen/late/expander.py @@ -51,7 +51,7 @@ def do_expand(root:UOp): new_srcs.append(src) elif src.dtype.count > 1: # put any input dtype > 1 grouped together - new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz)) + new_srcs.append(UOp(Ops.VCAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz)) else: # repeat the arg new_srcs.append(src.broadcast(expand_sz)) diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 7772288073..2f8d618076 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -101,7 +101,7 @@ class Ops(FastEnum): REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # expander ops - UNROLL = auto(); CONTRACT = auto(); CAT = auto(); PTRCAT = auto() + UNROLL = auto(); CONTRACT = auto(); VCAT = auto(); PTRCAT = auto() class GroupOp: Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIPROCAL, Ops.NEG, Ops.TRUNC} diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 971b81763e..460e1385b3 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -260,7 +260,7 @@ full_spec = PatternMatcher([ # PTRCAT is like VECTORIZE, but it functions on ptrs (UPat(Ops.PTRCAT, name="x"), lambda x: x.dtype.vcount == sum([y.dtype.base.count for y in x.src])), # CAT is like VECTORIZE, but the srcs can be vectors - (UPat(Ops.CAT, name="x"), lambda x: x.dtype.vcount == sum([y.dtype.vcount for y in x.src])), + (UPat(Ops.VCAT, name="x"), lambda x: x.dtype.vcount == sum([y.dtype.vcount for y in x.src])), # vectorized index (UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True), diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index c48a8fd26a..1aeeb7fa58 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -175,7 +175,7 @@ gep_pushing = PatternMatcher([ lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \ if not isinstance(gep.dtype, PtrDType) and not isinstance(alu.dtype, PtrDType) else None), # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later) - (UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \ + (UPat(Ops.VCAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \ if not isinstance(x.dtype, PtrDType) else None), # VECTORIZE on same GEP (UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),