mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
more
This commit is contained in:
@@ -721,7 +721,7 @@ class TestExpander(unittest.TestCase):
|
||||
self.assertTupleEqual(sink.src[0].arg, (0,2,1,3,4,6,5,7))
|
||||
|
||||
def test_contract_no_expand(self):
|
||||
e1 = UOp(Ops.DEFINE_VAR, dtypes.int)
|
||||
e1 = UOp.variable("i", 0, 10, dtype=dtypes.int)
|
||||
con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is Ops.VECTORIZE and len(sink.src) == 2
|
||||
|
||||
@@ -178,11 +178,6 @@ full_spec = PatternMatcher([
|
||||
# NOOP in the full spec
|
||||
(UPat(Ops.NOOP), lambda: True),
|
||||
|
||||
# Invalid must have type Index
|
||||
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
|
||||
# where on index in rhs position is fine
|
||||
(UPat(Ops.WHERE, src=(UPat(dtype=dtypes.bool), UPat(), UPat(dtype=dtypes.index))), lambda: True),
|
||||
|
||||
# all rewrite error are okay
|
||||
(UPat(Ops.REWRITE_ERROR), lambda: True),
|
||||
|
||||
@@ -205,8 +200,12 @@ full_spec = PatternMatcher([
|
||||
# linearizer: outputs + intermediate KERNELs
|
||||
(UPat(Ops.KERNEL, dtype=dtypes.void), lambda: True),
|
||||
|
||||
# Invalid must have type Index
|
||||
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
|
||||
# where on index in rhs position is fine
|
||||
(UPat(Ops.WHERE, dtype=dtypes.index, src=(UPat(dtype=dtypes.bool), UPat(), UPat(dtype=dtypes.index))), lambda: True),
|
||||
# allow index dtype on a restricted set of UOps
|
||||
(UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX, Ops.WHERE,
|
||||
(UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX,
|
||||
Ops.SPECIAL, Ops.CAST, Ops.RANGE, Ops.VCONST, Ops.VECTORIZE), dtype=dtypes.index), lambda: True),
|
||||
|
||||
# while BIND is being casted
|
||||
@@ -217,8 +216,8 @@ full_spec = PatternMatcher([
|
||||
|
||||
# all loads/stores
|
||||
(UPat((Ops.LOAD, Ops.STORE)), lambda: True),
|
||||
# all DEFINE_VAR to deal with the floats used in reduce collapse
|
||||
(UPat(Ops.DEFINE_VAR), lambda: True),
|
||||
# DEFINE_VAR to deal with the floats used in reduce collapse
|
||||
(UPat(Ops.DEFINE_VAR, dtype=dtypes.floats), lambda: True),
|
||||
# allow any AFTER
|
||||
(UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True),
|
||||
])+tensor_spec+kernel_spec+program_spec+shared_spec
|
||||
|
||||
Reference in New Issue
Block a user