diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 144d37833f..a67a1c1d7c 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -110,6 +110,9 @@ class TestSymbolic(unittest.TestCase): def test_sub_1(self): self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(a+-1)") + def test_const_var(self): + self.helper_test_variable(Variable("fake", 1, 1), 1, 1, "1") + def test_add_self(self): a = Variable("a", 0, 8) b = Variable("b", 0, 8) diff --git a/tinygrad/codegen/symbolic.py b/tinygrad/codegen/symbolic.py index c0a38ce1b4..c354d484b1 100644 --- a/tinygrad/codegen/symbolic.py +++ b/tinygrad/codegen/symbolic.py @@ -244,8 +244,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ # alu of two where with same conds can combine, only do if true branch or false branch is const (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \ lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None), - # ALU min==max -> CONST (slow!) - (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), + # ALU/variable min==max -> CONST (slow!) + (UPat(GroupOp.ALU|{Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), # max folding (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None), # TODO: why does this rule break beautiful_mnist?