mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 22:38:16 -05:00
move % inside UOp mod_folding and remove deprecated tests (#6085)
[run_process_replay]
This commit is contained in:
@@ -5,7 +5,7 @@ from tinygrad.dtype import PtrDType
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps
|
||||
from tinygrad.codegen.uops import UOps, UOp, NOp, PatternMatcher
|
||||
from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, reducer, constant_folder, float4_folding, mod_folding
|
||||
from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, reducer, constant_folder, float4_folding
|
||||
|
||||
simple_pm = PatternMatcher([
|
||||
(NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
|
||||
@@ -614,49 +614,6 @@ class TestIFUOps(TestUOps):
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 3)
|
||||
|
||||
class TestDivMod(TestUOps):
|
||||
def c(self, c:int): return UOp.const(dtypes.int, c)
|
||||
def x(self, expr:str, nmin:int, nmax:int): return UOp(UOps.DEFINE_VAR, dtypes.int, (self.c(nmin), self.c(nmax)), Variable(expr, nmin, nmax))
|
||||
|
||||
# NOTE: does not simplify to the end
|
||||
def test_const_mod(self):
|
||||
self.assert_equiv_uops(mod_folding(self.c(6), 3), self.c(1)*self.c(0))
|
||||
self.assert_equiv_uops(mod_folding(self.c(7), 3), self.c(1)*self.c(1))
|
||||
self.assert_equiv_uops(mod_folding(self.c(8), 3), self.c(1)*self.c(2))
|
||||
|
||||
def test_var_mod(self):
|
||||
self.assertIsNone(mod_folding(self.x("x", 0, 6), 3))
|
||||
self.assertIsNone(mod_folding(self.x("x", 0, 7), 3))
|
||||
|
||||
@unittest.skip("does not simplify to the end")
|
||||
def test_add_mod(self):
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)+40, 5), self.x("x", 0, 6))
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)-40, 5), self.x("x", 0, 6))
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)+42, 5), (self.x("x", 0, 6)+2))
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)-42, 5), (self.x("x", 0, 6)+3))
|
||||
self.assert_equiv_uops(mod_folding(40+self.x("x", 0, 6), 5), self.x("x", 0, 6))
|
||||
self.assert_equiv_uops(mod_folding(-40+self.x("x", 0, 6), 5), self.x("x", 0, 6))
|
||||
self.assert_equiv_uops(mod_folding(42+self.x("x", 0, 6), 5), (2+self.x("x", 0, 6)))
|
||||
self.assert_equiv_uops(mod_folding(-42+self.x("x", 0, 6), 5), (3+self.x("x", 0, 6)))
|
||||
|
||||
@unittest.skip("does not simplify to the end")
|
||||
def test_mul_mod(self):
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*40, 5), self.c(0))
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*-40, 5), self.c(0))
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*42, 5), (self.x("x", 0, 6)*2))
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*-42, 5), (self.x("x", 0, 6)*3))
|
||||
self.assert_equiv_uops(mod_folding(40*self.x("x", 0, 6), 5), self.c(0))
|
||||
self.assert_equiv_uops(mod_folding(-40*self.x("x", 0, 6), 5), self.c(0))
|
||||
self.assert_equiv_uops(mod_folding(42*self.x("x", 0, 6), 5), (2*self.x("x", 0, 6)))
|
||||
self.assert_equiv_uops(mod_folding(-42*self.x("x", 0, 6), 5), (3*self.x("x", 0, 6)))
|
||||
|
||||
@unittest.skip("does not simplify to the end now")
|
||||
def test_mul_add_mod(self):
|
||||
x = self.x("x", 0, 10)
|
||||
y = self.x("y", 0, 10)
|
||||
z = self.x("z", 0, 10)
|
||||
self.assert_equiv_uops(mod_folding(x*40+y*12+z, 5), (y*2+z))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -84,7 +84,7 @@ def _get_add_chain(x:UOp):
|
||||
else: yield x
|
||||
|
||||
def mod_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
# simplify x in x % c
|
||||
# simplify x % c
|
||||
# None means no change
|
||||
remainder, something_changed = [], False
|
||||
for u in _get_add_chain(x):
|
||||
@@ -96,7 +96,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
something_changed = True
|
||||
else: remainder.append(u)
|
||||
if not something_changed: return None
|
||||
return functools.reduce(operator.add, remainder) if remainder else x.const(0)
|
||||
return functools.reduce(operator.add, remainder)%c if remainder else x.const(0)
|
||||
|
||||
def div_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
# simplify x // c, None means no change
|
||||
@@ -287,8 +287,8 @@ constant_folder = PatternMatcher([
|
||||
(NOp.var('x') // NOp.cvar('c'), lambda x,c:
|
||||
newx if 0 < c.arg and not dtypes.is_unsigned(x.dtype) and (newx:=div_folding(x,c.arg)) is not None else None),
|
||||
# ** mod **
|
||||
# apply mod to mod input
|
||||
(NOp.var('x') % NOp.cvar('c'), lambda x,c: newx%c if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
|
||||
# mod folding
|
||||
(NOp.var('x') % NOp.cvar('c'), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
|
||||
# remove mod
|
||||
(NOp.var('x') % NOp.cvar('c'), lambda x,c:\
|
||||
x-(x.vmin.arg//c.arg)*c.arg if 0 < c.arg and 0 <= x.vmin.arg and x.vmin.arg//c.arg == x.vmax.arg//c.arg else None),
|
||||
|
||||
Reference in New Issue
Block a user