minor arg spec check on wmma (#8525)

This commit is contained in:
ignaciosica
2025-01-10 20:42:56 -03:00
committed by GitHub
parent d09897c2aa
commit 8891495996

View File

@@ -954,8 +954,8 @@ spec = PatternMatcher([
(UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
# all WMMA has 3 args, <x, w, acc>
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
# WMMA has a <a, b, acc>
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),