# Patch summary (commentary only; everything from the first "diff --git" header onward is unchanged).
#
# test/test_arange.py
#   - test_complexity_w_group: the `limit` passed to test_complexity drops from 100000 to 81920.
#
# test/unit/test_uop_symbolic.py
#   - test_div_neg_rem: for (-Variable("a", 0, 255)+256)//2 the expected rendering changes from
#     "(((a*-1)+256)//2)" to "((((a+1)//2)*-1)+128)"; the expected bounds (0, 128) are unchanged.
#
# tinygrad/codegen/symbolic.py
#   - The tinygrad.helpers import no longer pulls in cdiv and cmod.
#   - fold_unrolled_divs: the `offset` accumulator is removed. The trailing constant of each ADD term is
#     appended to seen_const as-is (previously cmod(const, denominator), with cdiv(const, denominator)
#     added to offset), and the function returns fac*ans instead of fac*(ans + offset).
#   - symbolic PatternMatcher: adds a rule matching ((x + c) named "n") // d, with c and d constants, that
#     rewrites to -(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg and returns None unless
#     x.vmax<=0 and n.vmin>=0 and d.arg>0.
#
# NOTE(review): in fold_unrolled_divs the context line
#   `if s0.op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:`
#   tests the same condition twice; this patch leaves it as-is. Confirm whether a different check was intended.
# NOTE(review): since the final test is sorted(seen_const)==list(range(denominator)), storing the raw
#   constant presumably means terms whose constant is outside [0, denominator) no longer fold. Confirm this
#   is intended; the parts of the function between hunks are not visible here.
# NOTE(review): the new (x+c)//d rule is listed after the generic `x // y` -> div_and_mod_folding rule;
#   if PatternMatcher tries rules in list order it only fires when that rule returns None. TODO confirm.
# NOTE(review): line breaks in this patch appear to have been collapsed onto two lines; it likely needs
#   its original newlines restored before `git apply` will accept it.
diff --git a/test/test_arange.py b/test/test_arange.py index a9eeab7ee6..3109ad0581 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -42,7 +42,7 @@ class TestArange(unittest.TestCase): if Device.default.renderer.has_local: # TODO: fix limit - def test_complexity_w_group(self): return self.test_complexity([Opt(OptOps.GROUP, 0, 16)], limit=100000) + def test_complexity_w_group(self): return self.test_complexity([Opt(OptOps.GROUP, 0, 16)], limit=81920) def test_complexity_w_group_top(self): return self.test_complexity([Opt(OptOps.GROUPTOP, 0, 16)], limit=106496) def test_complexity_w_local(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16)], limit=0) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 85ecc05083..1e0ae0a3cd 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -333,7 +333,7 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((Variable("a", 0, 5)+5)//4, 1, 2, "(((a+1)//4)+1)") def test_div_neg_rem(self): - self.helper_test_variable((-Variable("a", 0, 255)+256)//2, 0, 128, "(((a*-1)+256)//2)") + self.helper_test_variable((-Variable("a", 0, 255)+256)//2, 0, 128, "((((a+1)//2)*-1)+128)") def test_mul_div_factor_mul(self): self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)") diff --git a/tinygrad/codegen/symbolic.py b/tinygrad/codegen/symbolic.py index 9e8f0a4171..8e7dccee23 100644 --- a/tinygrad/codegen/symbolic.py +++ b/tinygrad/codegen/symbolic.py @@ -4,7 +4,7 @@ import math, operator, struct, functools from collections import defaultdict from tinygrad.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu from tinygrad.dtype import ConstType, dtypes, PtrDType -from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod +from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element from tinygrad.codegen.transcendental import xpow # ******** phase 1 of symbolic used to live 
in ops, it's the most generic folding rules ******** @@ -78,7 +78,7 @@ def split_uop(x:UOp, sep:Ops): def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None: # div pattern in unrolled arange # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x - seen_const, ans, offset = [], None, 0 + seen_const, ans = [], None for u in split_uop(divs, Ops.ADD): if fac!=1: if u.op is not Ops.MUL or u.src[1].op is not Ops.CONST or u.src[1].arg != fac: return None @@ -88,9 +88,7 @@ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None: if (s0:=u.src[0]).vmin < 0: return None # assumed CONST is the last of an ADD if s0.op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST: - const = s0.src[1].arg - offset += cdiv(const, denominator) - seen_const.append(cmod(const, denominator)) + seen_const.append(s0.src[1].arg) s0 = s0.src[0] else: seen_const.append(0) if ans is None: ans = s0 @@ -100,7 +98,7 @@ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None: for i in range(denominator-len(seen_const)): if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i) if sorted(seen_const)==list(range(denominator)): - return fac*(ans + offset) + return fac*ans return None def lt_folding(x:UOp, c:int) -> UOp|None: @@ -283,6 +281,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ # div folding ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)), # (x//c+a)//d -> (x+a*c)//(c*d) (UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)), + ((UPat.var("x", dtypes.sints)+UPat.cvar("c")).named("n")//UPat.cvar("d"), + lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None), # ** mod ** # mod folding (UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),