tests for symbolic_simple failing tensor const spec [pr] (#8469)

* tests for symbolic_simple failing tensor const spec [pr]

* mul is correct
This commit is contained in:
qazal
2025-01-02 13:13:16 +02:00
committed by GitHub
parent dc9af4e2fc
commit f2bee34197

View File

@@ -2048,6 +2048,26 @@ class TestBigGraph(unittest.TestCase):
sink = tensor_rewrite(a)
assert UPat.cvar(dtype=dtypes.int).match(sink, {})
def test_const_folding_mul(self):
a = Tensor([1])
sink = tensor_rewrite(a*0)
assert UPat(Ops.CONST, arg=0).match(sink, {}), f"expected {sink} to collapse to a const 0"
assert sink.shape == a.shape
@unittest.expectedFailure
def test_const_folding_ne(self):
a = Tensor([1])
sink = tensor_rewrite(a != a)
assert UPat(Ops.CONST, arg=False).match(sink, {}), f"expected {sink} to collapse to a const False"
assert sink.shape == a.shape
@unittest.expectedFailure
def test_const_folding_lt(self):
a = Tensor([1])
sink = tensor_rewrite(a < a)
assert UPat(Ops.CONST, arg=False).match(sink, {}), f"expected {sink} to collapse to a const False"
assert sink.shape == a.shape
tensor_const_pm = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),)), lambda: True),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST))))), lambda: True),