From f2bee341970f718298d020d7b38cb33239c67b73 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 2 Jan 2025 13:13:16 +0200 Subject: [PATCH] tests for symbolic_simple failing tensor const spec [pr] (#8469) * tests for symbolic_simple failing tensor const spec [pr] * mul is correct --- test/test_schedule.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/test_schedule.py b/test/test_schedule.py index 4323d54839..a900995a5d 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -2048,6 +2048,26 @@ class TestBigGraph(unittest.TestCase): sink = tensor_rewrite(a) assert UPat.cvar(dtype=dtypes.int).match(sink, {}) + def test_const_folding_mul(self): + a = Tensor([1]) + sink = tensor_rewrite(a*0) + assert UPat(Ops.CONST, arg=0).match(sink, {}), f"expected {sink} to collapse to a const 0" + assert sink.shape == a.shape + + @unittest.expectedFailure + def test_const_folding_ne(self): + a = Tensor([1]) + sink = tensor_rewrite(a != a) + assert UPat(Ops.CONST, arg=False).match(sink, {}), f"expected {sink} to collapse to a const False" + assert sink.shape == a.shape + + @unittest.expectedFailure + def test_const_folding_lt(self): + a = Tensor([1]) + sink = tensor_rewrite(a < a) + assert UPat(Ops.CONST, arg=False).match(sink, {}), f"expected {sink} to collapse to a const False" + assert sink.shape == a.shape + tensor_const_pm = PatternMatcher([ (UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),)), lambda: True), (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST))))), lambda: True),