mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Fold constant variable (#10196)
* Add rule * add test and comment * merge rule
This commit is contained in:
@@ -110,6 +110,9 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_sub_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(a+-1)")
|
||||
|
||||
def test_const_var(self):
|
||||
self.helper_test_variable(Variable("fake", 1, 1), 1, 1, "1")
|
||||
|
||||
def test_add_self(self):
|
||||
a = Variable("a", 0, 8)
|
||||
b = Variable("b", 0, 8)
|
||||
|
||||
@@ -244,8 +244,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
# alu of two where with same conds can combine, only do if true branch or false branch is const
|
||||
(UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
|
||||
lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
|
||||
# ALU min==max -> CONST (slow!)
|
||||
(UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
# ALU/variable min==max -> CONST (slow!)
|
||||
(UPat(GroupOp.ALU|{Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
# max folding
|
||||
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
|
||||
# TODO: why does this rule break beautiful_mnist?
|
||||
|
||||
Reference in New Issue
Block a user