mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
update UOp mod reduction patterns (#5883)
prepare generic mod folding, also some test changes from mod folding pr
This commit is contained in:
@@ -239,8 +239,10 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_sum_num_hoisted_and_factors_cancel_out(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
|
||||
|
||||
def test_div_mod_factor(self):
|
||||
def test_div_cancel(self):
|
||||
self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
|
||||
|
||||
def test_mod_cancel(self):
|
||||
self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)")
|
||||
|
||||
def test_mul_div(self):
|
||||
|
||||
@@ -303,8 +303,11 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_div_mod_factor(self):
|
||||
def test_div_cancel(self):
|
||||
self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_mod_cancel(self):
|
||||
self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)")
|
||||
|
||||
def test_mul_div(self):
|
||||
|
||||
@@ -236,9 +236,9 @@ constant_folder = PatternMatcher([
|
||||
(((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')) // NOp.cvar('c1'), lambda x,x2,c0,c1:\
|
||||
x*(c0.arg//g)//(c1.arg//g) if c0.arg > 0 and c1.arg > 0 and (g:=math.gcd(c0.arg,c1.arg)) > 1 and g > x2.vmax.arg and x2.vmin.arg >= 0 else None),
|
||||
# ** mod **
|
||||
# mod folding and mod reduction
|
||||
(NOp.var('x') % NOp.cvar('c'), lambda x,c: x if 0 <= x.vmin.arg <= x.vmax.arg < c.arg else \
|
||||
(x-(x.vmin.arg//c.arg)*c.arg)%c if 0 < c.arg <= x.vmin.arg else None),
|
||||
# mod folding
|
||||
(NOp.var('x') % NOp.cvar('c'), lambda x,c:\
|
||||
x-(x.vmin.arg//c.arg)*c.arg if 0 < c.arg and 0 <= x.vmin.arg and x.vmin.arg//c.arg == x.vmax.arg//c.arg else None),
|
||||
# mul mod
|
||||
((NOp.cvar('c0')*NOp.var('x')) % NOp.cvar('c1'), lambda x,c0,c1:\
|
||||
x*(c0.arg%c1.arg)%c1 if 0 < c1.arg <= c0.arg else (x%(c1.arg//c0.arg))*c0 if c1.arg%c0.arg == 0 else None),
|
||||
|
||||
Reference in New Issue
Block a user