refactor VECTORIZE/GEP rules (#5507)

This commit is contained in:
gswangg
2024-07-16 09:41:23 -07:00
committed by GitHub
parent 173064c69c
commit 203161c75d

View File

@@ -336,9 +336,8 @@ constant_folder = PatternMatcher([
lambda x: UOp(x.op, dtypes.int32, x.src, x.arg)),
# VECTORIZE/GEP
(UOp(UOps.GEP, src=(UOp(UOps.VECTORIZE).name("cast"),)).name("gep"), lambda gep, cast: cast.src[gep.arg]),
(UOp(UOps.VECTORIZE, dtypes.float.vec(2), tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(2))), lambda x: x),
(UOp(UOps.VECTORIZE, dtypes.float.vec(4), tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(4))), lambda x: x),
(UOp(UOps.VECTORIZE, dtypes.float.vec(8), tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(8))), lambda x: x),
*[(UOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=j)
for j in range(i))), lambda x: x) for i in [2, 4, 8]],
# tensor core with a 0 input is acc
(UOp(UOps.WMMA, src=(UOp.const(None, 0.0), UOp.var(), UOp.var('acc'))), lambda acc: acc),
(UOp(UOps.WMMA, src=(UOp.var(), UOp.const(None, 0.0), UOp.var('acc'))), lambda acc: acc),