mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
UOp pattern x + x -> x * 2 (#6224)
* UOp pattern x + x -> x * 2 now there's no NEG, with this it covers all kinds of a*x+b*x * can remove x-x
This commit is contained in:
@@ -97,8 +97,13 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_sub_num_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, "(-1+a)")
|
||||
|
||||
def test_add_self(self):
|
||||
a = Variable("a", 0, 8)
|
||||
self.helper_test_variable(a+a, 0, 16, "(a*2)")
|
||||
|
||||
def test_sub_self(self):
|
||||
a = Variable("a", 0, 8)
|
||||
self.helper_test_variable(a-a, 0, 0, "0")
|
||||
self.helper_test_variable(a*3-a, 0, 16, "(a*2)")
|
||||
|
||||
def test_mul_0(self):
|
||||
|
||||
@@ -150,8 +150,13 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_sub_num_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, {"(-1+a)", "(a+(-1))"})
|
||||
|
||||
def test_add_self(self):
|
||||
a = Variable("a", 0, 8)
|
||||
self.helper_test_variable(a+a, 0, 16, "(a*2)")
|
||||
|
||||
def test_sub_self(self):
|
||||
a = Variable("a", 0, 8)
|
||||
self.helper_test_variable(a-a, 0, 0, "0")
|
||||
self.helper_test_variable(a*3-a, 0, 16, "(a*2)")
|
||||
|
||||
def test_mul_0(self):
|
||||
|
||||
@@ -269,7 +269,6 @@ constant_folder = PatternMatcher([
|
||||
# ** constant folding **
|
||||
(UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: root.const(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
|
||||
# ** self folding **
|
||||
(-(-NOp.var('x')), lambda x: x), # -(-x) -> x
|
||||
(NOp.var('x') + 0, lambda x: x), # x+0 -> x
|
||||
(NOp.var('x') * 1, lambda x: x), # x*1 -> x
|
||||
(NOp.var('x') // NOp.var('x'), lambda x: x.const(1)), # x//x -> 1
|
||||
@@ -283,8 +282,6 @@ constant_folder = PatternMatcher([
|
||||
# if x is nan or inf it should render the nan value.
|
||||
# NOTE: this can be wrong for loaded NaN
|
||||
(NOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
# x-x -> 0
|
||||
(NOp.var('x') - NOp.var('x'), lambda x: x.const(0)),
|
||||
# min==max -> CONST (slow!)
|
||||
(UPat({UOps.ALU, UOps.DEFINE_VAR}, name='x'), lambda x: x.const(x.vmin.arg) if x.vmin.arg == x.vmax.arg else None),
|
||||
# ** load/store folding **
|
||||
@@ -322,6 +319,10 @@ constant_folder = PatternMatcher([
|
||||
x*c1+c0.arg*c1.arg if dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
|
||||
# (x*c0)+(x*c1) -> x*(c0+c1)
|
||||
(NOp.var("x") * NOp.cvar("c0") + NOp.var("x") * NOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
|
||||
# (x+x*c)-> x*(c+1)
|
||||
(NOp.var("x") + NOp.var("x") * NOp.cvar("c"), lambda x,c: x*(c.arg+1)),
|
||||
# (x+x)-> x*2
|
||||
(NOp.var("x") + NOp.var("x"), lambda x: x*2),
|
||||
# (x*c0)+(y*c0) -> (x+y)*c0
|
||||
#((NOp.var("x") * NOp.cvar("c0")) + (NOp.var("y") * NOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
|
||||
# (x*x2)/x2 -> x
|
||||
@@ -332,8 +333,6 @@ constant_folder = PatternMatcher([
|
||||
((NOp.var("x") / NOp.var("x2")) / NOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
|
||||
# c0 + x < c1 -> x < c1 - c0
|
||||
((NOp.cvar("c0") + NOp.var("x")).lt(NOp.cvar("c1")), lambda x,c0,c1: UOp.lt(x, x.const(exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
|
||||
# (x+x*c0)-> x*(c0+1)
|
||||
(NOp.var("x") + NOp.var("x") * NOp.cvar("c0"), lambda x,c0: x*(c0.arg+1)),
|
||||
# x!=0 -> (bool)x
|
||||
(NOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
|
||||
# TODO: can do the invert of this (flip alt/load) when we fix double ops
|
||||
|
||||
Reference in New Issue
Block a user