Fold constant variable (#10196)

* Add rule

* add test and comment

* merge rule
This commit is contained in:
Sieds Lykles
2025-05-07 20:39:44 +02:00
committed by GitHub
parent 8386527bb9
commit 2891892834
2 changed files with 5 additions and 2 deletions

View File

@@ -110,6 +110,9 @@ class TestSymbolic(unittest.TestCase):
def test_sub_1(self):
self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(a+-1)")
def test_const_var(self):
self.helper_test_variable(Variable("fake", 1, 1), 1, 1, "1")
def test_add_self(self):
a = Variable("a", 0, 8)
b = Variable("b", 0, 8)

View File

@@ -244,8 +244,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
# alu of two where with same conds can combine, only do if true branch or false branch is const
(UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
# ALU min==max -> CONST (slow!)
(UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# ALU/variable min==max -> CONST (slow!)
(UPat(GroupOp.ALU|{Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
# TODO: why does this rule break beautiful_mnist?