diff --git a/tinygrad/ops.py b/tinygrad/ops.py index dff11e89a0..5ed433f940 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -954,8 +954,8 @@ spec = PatternMatcher([ (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True), (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True), - # all WMMA has 3 args, - (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True), + # WMMA has a + (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8), (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),