diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 3e5692900d..949e37d68f 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -591,45 +591,6 @@ class TestSymbolic(unittest.TestCase): with self.assertRaises(AssertionError): self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)") - def test_arange_unrolled4(self): - gidx = Variable("gidx", 0, 2559) - unrolled_div = (gidx+2561)//4+(gidx+2562)//4+(gidx+2560)//4+(gidx+2559)//4 - self.helper_test_variable(unrolled_div, 2559, 5118, "(gidx+2559)") - - def test_arange_unrolled4_with_cast(self): - gidx = Variable("gidx", 0, 2559, dtypes.index) - dt = dtypes.int - unrolled_div = ((gidx+2561)//4 + 2).cast(dt)+((gidx+2562)//4).cast(dt)+((gidx+2560)//4).cast(dt)+((gidx+2559)//4).cast(dt) - self.helper_test_variable(unrolled_div, 2561, 5120, "((int)(gidx)+2561)") - - def test_arange_unrolled4_mul(self): - gidx = Variable("gidx", 0, 2559) - unrolled_div = 2*((gidx+2561)//4)+2*((gidx+2562)//4)+2*((gidx+2560)//4)+2*((gidx+2559)//4) - self.helper_test_variable(unrolled_div, 5118, 10236, "((gidx*2)+5118)") - - def test_arange_unrolled4_small(self): - gidx = Variable("gidx", 0, 3) - unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4 - self.helper_test_variable(unrolled_div, 0, 3, "gidx") - - gidx = Variable("gidx", 0, 2) - unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4 - self.helper_test_variable(unrolled_div, 0, 2, "gidx") - - gidx = Variable("gidx", 0, 1) - unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4 - self.helper_test_variable(unrolled_div, 0, 1, "gidx") - - def test_arange_unrolled2(self): - gidx = Variable("gidx", 0, 2559) - unrolled_div = (gidx+2559)//2+(gidx+2560)//2+3 - self.helper_test_variable(unrolled_div, 2562, 5121, "(gidx+2562)") - - def test_arange_unrolled2_neg(self): - ridx = Variable("ridx", 0, 255) - unrolled_div = -((255-ridx)//2) - ((256-ridx)//2) - self.helper_test_variable(unrolled_div, -255, 0, "(ridx+-255)") - def test_gated_load(self): idx = Variable("idx", 0, 24) self.helper_test_variable(idx//4, 0, 6, "(idx//4)") diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 4da3e29e15..3a8957f57c 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -121,33 +121,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([ # ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ******** -def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None: - # div pattern in unrolled arange - # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x - seen_const, ans = [], None - for u in divs.split_uop(Ops.ADD): - if fac!=1: - if u.op is not Ops.MUL or u.src[1].op is not Ops.CONST or u.src[1].arg != fac: return None - u = u.src[0] - if u.op is Ops.CAST and u.src[0].dtype == dtypes.index: u = u.src[0] - if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None - if denominator != u.src[1].arg: return None - if (s0:=u.src[0]).vmin < 0: return None - # assumed CONST is the last of an ADD - if s0.op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST: - seen_const.append(s0.src[1].arg) - s0 = s0.src[0] - else: seen_const.append(0) - if ans is None: ans = s0 - if ans is not s0: return None - if ans is None: return None - # the first (denominator-len(seen_const)) terms may have been folded to 0 already - for i in range(denominator-len(seen_const)): - if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i) - if sorted(seen_const)==list(range(denominator)): - return (fac*ans).cast(divs.dtype) - return None - def lt_folding(x:UOp, c:int) -> UOp|None: p, np = partition(x.split_uop(Ops.ADD), lambda u: u.const_factor() == 1) if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d: @@ -350,9 +323,6 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ ((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1), ((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1), # *** rules from symbolic *** - # unrolled arange div folding - ((UPat()+(UPat()//UPat.cvar("d", vec=False)).or_casted()).named("divs"), lambda divs,d: fold_unrolled_divs(divs, d.arg)), - ((UPat()+((UPat()//UPat.cvar("d", vec=False)).or_casted()*UPat.cvar("c"))).named("divs"), lambda divs,d,c: fold_unrolled_divs(divs, d.arg, c.arg)), # generic lt folding (UPat.var("x", dtypes.index)