diff --git a/test/test_schedule.py b/test/test_schedule.py index 4323d54839..a900995a5d 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -2048,6 +2048,26 @@ class TestBigGraph(unittest.TestCase): sink = tensor_rewrite(a) assert UPat.cvar(dtype=dtypes.int).match(sink, {}) + def test_const_folding_mul(self): + a = Tensor([1]) + sink = tensor_rewrite(a*0) + assert UPat(Ops.CONST, arg=0).match(sink, {}), f"expected {sink} to collapse to a const 0" + assert sink.shape == a.shape + + @unittest.expectedFailure + def test_const_folding_ne(self): + a = Tensor([1]) + sink = tensor_rewrite(a != a) + assert UPat(Ops.CONST, arg=False).match(sink, {}), f"expected {sink} to collapse to a const False" + assert sink.shape == a.shape + + @unittest.expectedFailure + def test_const_folding_lt(self): + a = Tensor([1]) + sink = tensor_rewrite(a < a) + assert UPat(Ops.CONST, arg=False).match(sink, {}), f"expected {sink} to collapse to a const False" + assert sink.shape == a.shape + tensor_const_pm = PatternMatcher([ (UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),)), lambda: True), (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST))))), lambda: True),