diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index bbd3553fd6..652ae30c4b 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -5,7 +5,7 @@ from tinygrad.dtype import PtrDType from tinygrad.helpers import DEBUG from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps from tinygrad.codegen.uops import UOps, UOp, NOp, PatternMatcher -from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, reducer, constant_folder, float4_folding, mod_folding +from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, reducer, constant_folder, float4_folding simple_pm = PatternMatcher([ (NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)), @@ -614,49 +614,6 @@ class TestIFUOps(TestUOps): for st in sink.src: self.assertEqual(len(st.src), 3) -class TestDivMod(TestUOps): - def c(self, c:int): return UOp.const(dtypes.int, c) - def x(self, expr:str, nmin:int, nmax:int): return UOp(UOps.DEFINE_VAR, dtypes.int, (self.c(nmin), self.c(nmax)), Variable(expr, nmin, nmax)) - - # NOTE: does not simplify to the end - def test_const_mod(self): - self.assert_equiv_uops(mod_folding(self.c(6), 3), self.c(1)*self.c(0)) - self.assert_equiv_uops(mod_folding(self.c(7), 3), self.c(1)*self.c(1)) - self.assert_equiv_uops(mod_folding(self.c(8), 3), self.c(1)*self.c(2)) - - def test_var_mod(self): - self.assertIsNone(mod_folding(self.x("x", 0, 6), 3)) - self.assertIsNone(mod_folding(self.x("x", 0, 7), 3)) - - @unittest.skip("does not simplify to the end") - def test_add_mod(self): - self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)+40, 5), self.x("x", 0, 6)) - self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)-40, 5), self.x("x", 0, 6)) - self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)+42, 5), (self.x("x", 0, 6)+2)) - self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)-42, 5), (self.x("x", 0, 6)+3)) - self.assert_equiv_uops(mod_folding(40+self.x("x", 0, 6), 5), self.x("x", 0, 6)) - self.assert_equiv_uops(mod_folding(-40+self.x("x", 0, 6), 5), self.x("x", 0, 6)) - self.assert_equiv_uops(mod_folding(42+self.x("x", 0, 6), 5), (2+self.x("x", 0, 6))) - self.assert_equiv_uops(mod_folding(-42+self.x("x", 0, 6), 5), (3+self.x("x", 0, 6))) - - @unittest.skip("does not simplify to the end") - def test_mul_mod(self): - self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*40, 5), self.c(0)) - self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*-40, 5), self.c(0)) - self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*42, 5), (self.x("x", 0, 6)*2)) - self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*-42, 5), (self.x("x", 0, 6)*3)) - self.assert_equiv_uops(mod_folding(40*self.x("x", 0, 6), 5), self.c(0)) - self.assert_equiv_uops(mod_folding(-40*self.x("x", 0, 6), 5), self.c(0)) - self.assert_equiv_uops(mod_folding(42*self.x("x", 0, 6), 5), (2*self.x("x", 0, 6))) - self.assert_equiv_uops(mod_folding(-42*self.x("x", 0, 6), 5), (3*self.x("x", 0, 6))) - - @unittest.skip("does not simplify to the end now") - def test_mul_add_mod(self): - x = self.x("x", 0, 10) - y = self.x("y", 0, 10) - z = self.x("z", 0, 10) - self.assert_equiv_uops(mod_folding(x*40+y*12+z, 5), (y*2+z)) - if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 0ce6baa0c8..e483c40ac7 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -84,7 +84,7 @@ def _get_add_chain(x:UOp): else: yield x def mod_folding(x:UOp, c:int) -> Optional[UOp]: - # simplify x in x % c + # simplify x % c # None means no change remainder, something_changed = [], False for u in _get_add_chain(x): @@ -96,7 +96,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]: something_changed = True else: remainder.append(u) if not something_changed: return None - return functools.reduce(operator.add, remainder) if remainder else x.const(0) + return functools.reduce(operator.add, remainder)%c if remainder else x.const(0) def div_folding(x:UOp, c:int) -> Optional[UOp]: # simplify x // c, None means no change @@ -287,8 +287,8 @@ constant_folder = PatternMatcher([ (NOp.var('x') // NOp.cvar('c'), lambda x,c: newx if 0 < c.arg and not dtypes.is_unsigned(x.dtype) and (newx:=div_folding(x,c.arg)) is not None else None), # ** mod ** - # apply mod to mod input - (NOp.var('x') % NOp.cvar('c'), lambda x,c: newx%c if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None), + # mod folding + (NOp.var('x') % NOp.cvar('c'), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None), # remove mod (NOp.var('x') % NOp.cvar('c'), lambda x,c:\ x-(x.vmin.arg//c.arg)*c.arg if 0 < c.arg and 0 <= x.vmin.arg and x.vmin.arg//c.arg == x.vmax.arg//c.arg else None),