This commit is contained in:
George Hotz
2025-10-28 09:41:58 +08:00
parent 7d26342ab6
commit 77aadcb01d
2 changed files with 8 additions and 9 deletions

View File

@@ -721,7 +721,7 @@ class TestExpander(unittest.TestCase):
self.assertTupleEqual(sink.src[0].arg, (0,2,1,3,4,6,5,7))
def test_contract_no_expand(self):
e1 = UOp(Ops.DEFINE_VAR, dtypes.int)
e1 = UOp.variable("i", 0, 10, dtype=dtypes.int)
con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is Ops.VECTORIZE and len(sink.src) == 2

View File

@@ -178,11 +178,6 @@ full_spec = PatternMatcher([
# NOOP in the full spec
(UPat(Ops.NOOP), lambda: True),
# Invalid must have type Index
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
# where on index in rhs position is fine
(UPat(Ops.WHERE, src=(UPat(dtype=dtypes.bool), UPat(), UPat(dtype=dtypes.index))), lambda: True),
# all rewrite errors are okay
(UPat(Ops.REWRITE_ERROR), lambda: True),
@@ -205,8 +200,12 @@ full_spec = PatternMatcher([
# linearizer: outputs + intermediate KERNELs
(UPat(Ops.KERNEL, dtype=dtypes.void), lambda: True),
# Invalid must have type Index
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
# where on index in rhs position is fine
(UPat(Ops.WHERE, dtype=dtypes.index, src=(UPat(dtype=dtypes.bool), UPat(), UPat(dtype=dtypes.index))), lambda: True),
# allow index dtype on a restricted set of UOps
(UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX, Ops.WHERE,
(UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX,
Ops.SPECIAL, Ops.CAST, Ops.RANGE, Ops.VCONST, Ops.VECTORIZE), dtype=dtypes.index), lambda: True),
# while BIND is being cast
@@ -217,8 +216,8 @@ full_spec = PatternMatcher([
# all loads/stores
(UPat((Ops.LOAD, Ops.STORE)), lambda: True),
# all DEFINE_VAR to deal with the floats used in reduce collapse
(UPat(Ops.DEFINE_VAR), lambda: True),
# DEFINE_VAR to deal with the floats used in reduce collapse
(UPat(Ops.DEFINE_VAR, dtype=dtypes.floats), lambda: True),
# allow any AFTER
(UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True),
])+tensor_spec+kernel_spec+program_spec+shared_spec