diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 98688461c8..857c737144 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -109,7 +109,7 @@ jobs: - name: Train MNIST run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py | tee beautiful_mnist.txt - name: Run 10 CIFAR training steps - run: BENCHMARK_LOG=cifar_10steps JIT=1 ASSERT_MIN_STEP_TIME=320 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar.txt + run: BENCHMARK_LOG=cifar_10steps JIT=1 ASSERT_MIN_STEP_TIME=330 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar.txt - name: Run 10 CIFAR training steps w HALF run: BENCHMARK_LOG=cifar_10steps_half JIT=2 ASSERT_MIN_STEP_TIME=385 STEPS=10 DEFAULT_FLOAT=HALF python3.11 examples/hlb_cifar10.py | tee train_cifar_half.txt #- name: Run 10 CIFAR training steps w BF16 diff --git a/pytest.ini b/pytest.ini index b9c3f6064a..fe28b7e961 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,6 @@ [pytest] norecursedirs = extra -timeout = 180 +timeout = 240 timeout_method = thread timeout_func_only = true testpaths = test diff --git a/test/unit/test_rewrite_map.py b/test/unit/test_rewrite_map.py index a299888725..0e4d4c7772 100644 --- a/test/unit/test_rewrite_map.py +++ b/test/unit/test_rewrite_map.py @@ -28,22 +28,6 @@ class TestRewriteMap(unittest.TestCase): self.assertIs(sub_map[a+b], e) self.assertIs(sub_map[(a+b)*c], f) - def test_multistage_substitute(self): - a = UOp.variable('a', 0, 10) - b = UOp.variable('b', 0, 10) - c = UOp.variable('c', 0, 10) - d = UOp.variable('d', 0, 10) - sub1 = {a+b:c} - start = (a+b)*c - # stage 1: (a+b)*c -> c*c - sub_map1 = graph_rewrite_map(start, _substitute, sub1, bottom_up=True) - self.assertIs(sub_map1[(a+b)*c], c*c) - # stage 2: c*c -> d - sub2 = {c*c:d} - sub_map2 = graph_rewrite_map(sub_map1[start], _substitute, sub2, input_map=sub_map1, bottom_up=True) - # (a+b)*c -> c*c -> d - self.assertIs(sub_map2[(a+b)*c], d) 
- def test_add_zero(self): # Build a small graph: add(0, add(const=0, const=5)) zero_node = UOp.const(dtypes.index, 0) @@ -144,11 +128,11 @@ class TestRewriteMap(unittest.TestCase): yz_sum_zero = yz_sum + zero_node -> rewrites to yz_sum yz_neg = -yz_sum_zero -> -(y+z) yz_dneg = -yz_neg -> y+z (double neg gone) - x_plus_yz = x_var + yz_dneg -> x + (y+z) - double_neg_x = -(-x_plus_yz) -> x + (y+z) - final_expr = double_neg_x * one_node -> x + (y+z) + x_plus_yz = x_var + yz_dneg -> (x+y)+z (add nodes get sorted) + double_neg_x = -(-x_plus_yz) -> (x+y)+z + final_expr = double_neg_x * one_node -> (x+y)+z - We expect the final result to be (x + (y+z)). + We expect the final result to be ((x+y)+z). Each original node should map to the final node that replaces it, which might be structurally equivalent but not the same reference. """ @@ -163,9 +147,9 @@ class TestRewriteMap(unittest.TestCase): yz_sum_zero = yz_sum + zero_node # (y + z) + 0 yz_neg = -yz_sum_zero # -(y+z) yz_dneg = -yz_neg # -(-(y+z)) -> (y+z) - x_plus_yz = x_var + yz_dneg # x + (y+z) - double_neg_x = -(-x_plus_yz) # neg(neg(x+(y+z))) -> x+(y+z) - final_expr = double_neg_x * one_node # (x+(y+z)) * 1 -> x+(y+z) + x_plus_yz = x_var + yz_dneg # x + (y+z) -> (x+y)+z + double_neg_x = -(-x_plus_yz) # neg(neg(x+(y+z))) -> (x+y)+z + final_expr = double_neg_x * one_node # ((x+y)+z) * 1 -> (x+y)+z node_map = graph_rewrite_map(final_expr, symbolic) @@ -182,14 +166,15 @@ class TestRewriteMap(unittest.TestCase): # -(-(y+z)) => (y+z) self.assertEqual(node_map[yz_dneg], yz_sum) - # x + (y+z) => might get recreated if yz_dneg was changed, so compare to x + yz_sum - self.assertEqual(node_map[x_plus_yz], x_var + yz_sum) + # x + (y+z) => (x+y)+z + expected_xyz = (x_var + y_var) + z_var + self.assertEqual(node_map[x_plus_yz], expected_xyz) - # -(-(x+(y+z))) => x + (y+z) - self.assertEqual(node_map[double_neg_x], x_var + yz_sum) + # -(-(x+(y+z))) => (x+y)+z + self.assertEqual(node_map[double_neg_x], expected_xyz) - # (x+(y+z)) * 
1 => x+(y+z) - self.assertEqual(node_map[final_expr], x_var + yz_sum) + # ((x+y)+z) * 1 => (x+y)+z + self.assertEqual(node_map[final_expr], expected_xyz) # Unchanged atomic nodes map to themselves self.assertEqual(node_map[x_var], x_var) diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index bbc300a062..7a27cc83ca 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -60,7 +60,7 @@ class TestValidIdxSimplification(unittest.TestCase): load = get_gated_load_uop(gate, idx) self.check(load, "0", - "(((lidx0+(gidx0*4))<19)!=True)") + "((((gidx0*4)+lidx0)<19)!=True)") def test_simplify_within_valid1(self): ridx0 = Range(0, 4) @@ -184,7 +184,6 @@ class TestValidIdxSimplification(unittest.TestCase): print("The expressions are not equivalent.") print(s.model()) - @unittest.expectedFailure # TODO: improve uop_given_valid def test_valid_becomes_const2(self): ridx0 = Range(0, 4) ridx1 = Range(1, 4) @@ -304,7 +303,7 @@ class TestImageSimplification(unittest.TestCase): idx = ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2)) load = get_load_image_uop(shape, valid, idx) - self.check(load, None, "((((idx1*48)+(r2*6))+r0)+-6)", "(((idx2*2)+r1)+-1)") + self.check(load, None, "((((idx1*48)+r0)+(r2*6))+-6)", "(((idx2*2)+r1)+-1)") def test_openpilot_conv2(self): # conv in test/external/external_test_valid_remove.py @@ -325,7 +324,7 @@ class TestImageSimplification(unittest.TestCase): idx = ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2)) load = get_load_image_uop(shape, valid, idx) - self.check(load, None, "((((idx1*24)+(r2*3))+r0)+-3)", "(((idx2*2)+r1)+-1)") + self.check(load, None, "((((idx1*24)+r0)+(r2*3))+-3)", "(((idx2*2)+r1)+-1)") def test_openpilot_conv3(self): # in openpilot 0.9.7 @@ -347,8 +346,8 @@ class TestImageSimplification(unittest.TestCase): self.check(load, "((((idx2*2)+r0)<11)&((((idx1*8)+r1)<3)!=True))", - "(((idx0+((idx1*512)+(r1*64)))+832)%1024)", - 
"((((idx2*2)+r0)+(((idx1+((r1+5)//8))+1)//2))+-4)") + "(((idx0+(idx1*512))+(r1*64))+-192)", + "((((idx2*2)+(((idx1+((r1+5)//8))+1)//2))+r0)+-4)") def test_simplify1(self): # idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1) @@ -386,16 +385,16 @@ class TestImageSimplification(unittest.TestCase): # TODO: can this be simplified further? load = get_load_image_uop(shape, alu9, (((alu8+(alu2*8))%64),(alu2//8))) - self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+8)%64)", "((idx0%8)//2)") + self.check(load, "(idx0<256)", "((((idx0//32)+((idx0%8)*32))+8)%64)", "((idx0%8)//2)") load = get_load_image_uop(shape, alu9, (((alu8+(alu3*8))%64),(alu3//8))) - self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+16)%64)", "((idx0%8)//2)") + self.check(load, "(idx0<256)", "((((idx0//32)+((idx0%8)*32))+16)%64)", "((idx0%8)//2)") load = get_load_image_uop(shape, alu9, (((alu8+(alu4*8))%64),(alu4//8))) - self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+24)%64)", "((idx0%8)//2)") + self.check(load, "(idx0<256)", "((((idx0//32)+((idx0%8)*32))+24)%64)", "((idx0%8)//2)") load = get_load_image_uop(shape, alu9, (((alu8+(alu5*8))%64),(alu5//8))) - self.check(load, "(idx0<256)", "((((idx0%8)*32)+(idx0//32))%64)", "((idx0%8)//2)") + self.check(load, "(idx0<256)", "(((idx0//32)+((idx0%8)*32))%64)", "((idx0%8)//2)") def test_simplify5(self): # openpilot 0.9.7, chunk replacement to simplify diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 8c0bd638e5..b961efee60 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -116,6 +116,39 @@ class TestSymbolic(unittest.TestCase): self.assertEqual((a*b*3+a*b*b).divide_exact(a*b).simplify(), b+3) self.assertEqual((((a*-2)+14)*b).divide_exact(((a*-2)+14)).simplify(), b) + def helper_test_factor(self, expr, *factors): + factored = expr.factor(*factors) + self.check_equal_z3(expr, factored) + for fac in factors: self.assertIn(fac, 
factored.toposort()) + + def test_uop_factor(self): + a = Variable("a", 0, 8) + b = Variable("b", 0, 8) + c = Variable("c", 0, 8) + self.helper_test_factor((1400*a+2800*b), (a+2*b)) + self.helper_test_factor((1400*a+2800*b)%9000, (a+2*b)) + self.helper_test_factor((a+2*b), (a+2*b)) + self.helper_test_factor((a+c+2*b), (a+2*b)) + self.helper_test_factor((1400*a+c+2800*b)%9000, (a+2*b)) + self.helper_test_factor((1399*a+c+2800*b)%9000+1400*a+2800*b, (a+2*b)) + self.helper_test_factor((1400*a+c+2800*b)%9000+1400*a+2800*b, (a+2*b)) + # self.assertIsNone((a+c+3*b).factor(a+2*b)) + # self.assertIsNone((1399*a+c+2800*b).factor(a+2*b)) + + def test_uop_multiple_factors(self): + a = Variable("a", 0, 8) + b = Variable("b", 0, 8) + c = Variable("c", 0, 8) + d = Variable("d", 0, 8) + self.helper_test_factor((1400*a+2800*b+2*c+d), (a+2*b), (2*c+d)) + self.helper_test_factor((100*a+200*b+5*c), (a+2*b), (5*c)) + self.helper_test_factor((3*a+6*b+2*c+4*d), (a+2*b), (c+2*d)) + self.helper_test_factor((7*a+14*b+3*c+6*d), (a+2*b), (3*c+6*d)) + self.helper_test_factor((10*a+20*b+10*c+30*d), (a+2*b), (c+3*d)) + self.helper_test_factor((10*c+(10*a+20*b)//3+30*d), (a+2*b), (c+3*d)) + self.helper_test_factor((10*c+(10*a+20*b)//3+30*d), (a+2*b), (c+3*d)) + # self.assertIsNone((7*a+14*b+3*c+6*d).factor((a+8*b), (2*c+6*d))) + def test_divide_exact_not(self): a = Variable("a", 1, 8) b = Variable("b", 1, 8) @@ -130,13 +163,13 @@ class TestSymbolic(unittest.TestCase): a = Variable("a", 0, 8) b = Variable("b", 0, 8) self.helper_test_variable(a*2+a*3, 0, 8*5, "(a*5)") - self.helper_test_variable(b+a*2+a*3, 0, 8*6, "(b+(a*5))") + self.helper_test_variable(b+a*2+a*3, 0, 8*6, "((a*5)+b)") def test_factorize_no_mul(self): a = Variable("a", 0, 8) b = Variable("b", 0, 8) self.helper_test_variable(a+a*3, 0, 8*4, "(a*4)") - self.helper_test_variable((a+b)+a*3, 0, 8*5, "(b+(a*4))") + self.helper_test_variable((a+b)+a*3, 0, 8*5, "((a*4)+b)") self.helper_test_variable((a*3+b)+b*3, 0, 8*7, "((a*3)+(b*4))") 
def test_neg(self): @@ -159,8 +192,15 @@ class TestSymbolic(unittest.TestCase): b = Variable("b", 0, 8) self.helper_test_variable(a+a, 0, 16, "(a*2)") self.helper_test_variable((a+b)+b, 0, 24, "(a+(b*2))") - self.helper_test_variable((a*3+b)+a, 0, 40, "(b+(a*4))") - self.helper_test_variable((a+b)+a*3, 0, 40, "(b+(a*4))") + self.helper_test_variable((a*3+b)+a, 0, 40, "((a*4)+b)") + self.helper_test_variable((a+b)+a*3, 0, 40, "((a*4)+b)") + + def test_add_self_seperated(self): + a = Variable("a", 0, 8) + b = Variable("b", 0, 8) + c = Variable("c", 0, 8) + self.helper_test_variable((a+b)+c+a, 0, 32, "(((a*2)+b)+c)") + self.helper_test_variable((a*3+b*2)+c*2+a*5, 0, 96, "(((a*8)+(b*2))+(c*2))") def test_sub_self(self): a = Variable("a", 0, 8) @@ -279,7 +319,7 @@ class TestSymbolic(unittest.TestCase): def test_mod_congruence_multiple_vars(self): self.helper_test_variable((9+9*Variable("x",0,3)+9*Variable("y",0,3))%10, 3, 9, "(((x*-1)+(y*-1))+9)") self.helper_test_variable((7+9*Variable("x",0,2)+9*Variable("y",0,2)+Variable("z",0,2))%10, 3, 9, - ("(((z+(x*-1))+(y*-1))+7)", "(((y*-1)+(z+(x*-1)))+7)")) + ("(((z+(x*-1))+(y*-1))+7)", "(((y*-1)+(z+(x*-1)))+7)", "((((x*-1)+(y*-1))+z)+7)")) self.helper_test_variable((10+12*Variable("x",0,2)+Variable("y", 0, 4)%3)%13, 8, 12, "(((x*-1)+(y%3))+10)") def test_div_congruence(self): @@ -455,7 +495,7 @@ class TestSymbolic(unittest.TestCase): ridx1005 = UOp.variable("ridx1005", 0, 2) ridx1006 = UOp.variable("ridx1006", 0, 2) self.helper_test_variable((lidx1+((gidx1*18)+(ridx1005*18)+(lidx0*162))+(gidx0*2)+(ridx1006*2)+-40)//18, -2, 20, - "(((((lidx1+(((gidx1*18)+(ridx1005*18))+(lidx0*162)))+(gidx0*2))+(ridx1006*2))+-40)//18)") + "((((((((gidx0*2)+(gidx1*18))+(lidx0*162))+lidx1)+(ridx1005*18))+(ridx1006*2))+-40)//18)") def test_add_div(self): # careful about the lower bounds and upper bounds @@ -498,7 +538,7 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((d1*a*b*d1)//(d1), -1000, 1000, "(a*(b*d1))", test_z3=False) 
self.helper_test_variable((d1*a*d2*b*d1)//(d1*d2), -1000, 1000, "(a*(b*d1))", test_z3=False) self.helper_test_variable((d1*a + b*d1)//(d1), -20, 20, "(a+b)", test_z3=False) - self.helper_test_variable((d1*a + b*d1 + c*d1)//(d1), -30, 30, "(c+(a+b))", test_z3=False) + self.helper_test_variable((d1*a + b*d1 + c*d1)//(d1), -30, 30, "((a+b)+c)", test_z3=False) self.helper_test_variable((3*a*d1 + 9*b*d1)//(3*d1*d2), -40, 40, "(((a+(b*3))//(d2*-1))*-1)", test_z3=False) self.helper_test_variable((3*a*d1 + 9*b*d1+3)//(3*d1*d2), -401, 399, "(((((a*d1)+((b*d1)*3))+1)//((d1*d2)*-1))*-1)", test_z3=False) @@ -508,7 +548,7 @@ class TestSymbolic(unittest.TestCase): d = Variable("d", 1, 10) self.helper_test_variable((d*a+b)//d, 0, 20, "(a+(b//d))") self.helper_test_variable((d*a*20+b)//(5*d), 0, 42, "((a*4)+(b//(d*5)))") - self.helper_test_variable((d*a*20+b*d*5+10)//(5*d), 0, 52, "((b+(a*4))+(2//d))") + self.helper_test_variable((d*a*20+b*d*5+10)//(5*d), 0, 52, "(((a*4)+b)+(2//d))") def test_mod_gcd_factor_neg(self): self.helper_test_variable((Variable("a", 0, 10)*-4+4)%8, -4, 4, "((((a*-1)+1)%2)*4)") @@ -561,9 +601,7 @@ class TestSymbolic(unittest.TestCase): lidx2 = Variable("lidx2", 0, 3) alu0 = gidx2*640+gidx1*160+(gidx0//5)*2+lidx0*320+lidx1*10 self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192, - ("((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(((gidx2*32)+(gidx1*8))+(lidx0*16)))", - "(((lidx1+((lidx2+(gidx0//5))//5))//2)+((gidx2*32)+((gidx1*8)+(lidx0*16))))", - "((((gidx1*8)+(gidx2*32))+(lidx0*16))+((lidx1+((lidx2+(gidx0//5))//5))//2))")) + ("((((gidx1*8)+(gidx2*32))+(lidx0*16))+((lidx1+((lidx2+(gidx0//5))//5))//2))",)) def test_sum_div_complex2(self): gidx0 = Variable("gidx0", 0, 7) @@ -641,8 +679,21 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((gidx//4)*4+gidx%4, 0, 124, "gidx") self.helper_test_variable(lidx+gidx%4+(gidx//4)*4, 0, 248, "(gidx+lidx)") self.helper_test_variable(lidx+(gidx//4)*4+gidx%4, 0, 248, "(gidx+lidx)") - 
self.helper_test_variable(lidx+(gidx//4)*8+2*(gidx%4), 0, 372, "(lidx+(gidx*2))") - self.helper_test_variable(lidx+2*(gidx%4)+(gidx//4)*8, 0, 372, "(lidx+(gidx*2))") + self.helper_test_variable(lidx+(gidx//4)*8+2*(gidx%4), 0, 372, "((gidx*2)+lidx)") + self.helper_test_variable(lidx+2*(gidx%4)+(gidx//4)*8, 0, 372, "((gidx*2)+lidx)") + + def test_div_mod_recombine_seperated(self): + gidx = Variable("gidx", 0, 124) + lidx = Variable("lidx", 0, 124) + a = Variable("a", 0, 3) + b = Variable("b", 0, 3) + c = Variable("c", 0, 3) + self.helper_test_variable(gidx%4+a+b+c+(gidx//4)*4, 0, 133, "(((a+b)+c)+gidx)") + self.helper_test_variable((gidx//4)*4+a+b*10+gidx%4, 0, 157, "((a+(b*10))+gidx)") + self.helper_test_variable(lidx+gidx%4+a+b+c//2+(gidx//4)*4, 0, 255, "((((a+b)+gidx)+lidx)+(c//2))") + self.helper_test_variable(lidx+(gidx//4)*8+b+c+a*8+2*(gidx%4), 0, 402, "(((((a*8)+b)+c)+(gidx*2))+lidx)") + # TODO: need better sorting for this one + # self.helper_test_variable(lidx+(gidx//4)*4+a*3+b*3+(c*10)%3+gidx%4, , , "") def test_div_mod_recombine_folded_mod(self): a = Variable("a", 0, 2) @@ -1019,6 +1070,7 @@ class TestSymbolicRealWorld(unittest.TestCase): ("((((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352))+lidx3)+2207744)", '((lidx3+((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352)))+2207744)', '((lidx3+((lidx4*100352)+((gidx2*8)+((gidx1*784)+((gidx0*3211264)+((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49)))))))+2207744)', + '((((((((gidx0*3211264)+(gidx1*784))+(gidx2*8))+lidx3)+(lidx4*100352))+(((lidx5+1)//16)*802816))+(((lidx5+1)%16)*49))+2207744)', )) class TestBounds(unittest.TestCase): diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index c0a21bfa29..757bc4a51c 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -8,7 +8,7 @@ from tinygrad.uop.mathtraits import MathTrait from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, 
truncate, PtrDType, least_upper_dtype, Invalid, InvalidType from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, RANGEIFY, VIZ, SPEC -from tinygrad.helpers import strip_parens +from tinygrad.helpers import strip_parens, make_tuple if TYPE_CHECKING: from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.device import Buffer, MultiBuffer @@ -150,6 +150,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def tuplize(self:UOp) -> tuple: return (self.op.value, self.arg, self.dtype,)+tuple([x.tuplize for x in self.src]) + @functools.cached_property + def order_add(self:UOp) -> tuple: + if self.op is Ops.MUL and self.src[1].op in (Ops.CONST, Ops.VCONST): return (self.src[0].tuplize, make_tuple(self.src[1].arg, 1)) + return (self.tuplize, (0,)) + @property def ptrdtype(self) -> PtrDType: if not isinstance(self.dtype, PtrDType): raise RuntimeError("ptrdtype called on UOp without PtrDType") @@ -241,9 +246,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def simplify(self, tracked=False): # late import! 
- from tinygrad.uop.symbolic import symbolic + from tinygrad.uop.symbolic import symbolic_flat with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value): - return graph_rewrite(self, symbolic, name="simplify") + return graph_rewrite(self, symbolic_flat, name="simplify") def ssimplify(self) -> UOp|ConstType: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret def _eval(self, dtype, expected_type:Type[T]) -> T: assert self.dtype in dtype, f"eval with wrong dtype {self}" @@ -568,6 +573,32 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1] if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1 return None # generic None if we aren't sure + def factor(self, *factors: UOp) -> UOp: + # factor out expr from self if possible, might return self + # (1400*a + 2800*b + c).factor(a+2*b) -> 1400*(a+2*b) + c + if self.dtype in dtypes.floats: return self + if self.op is Ops.ADD: + factored = [] + # dict of {term: const_factor}, i.e. 
{a: 1, b: 2} + remainders = dict([(u.divides(f:=u.const_factor()).simplify(),f) for u in self.split_uop(Ops.ADD)]) + for fac in factors: + if fac.dtype not in (dtypes.index,)+dtypes.ints: continue + fac_terms = dict((u.divides(f:=u.const_factor()).simplify(),f) for u in fac.split_uop(Ops.ADD)) + factored_terms = {k:v for k,v in remainders.items() if k in fac_terms} + new_remainders = {k:v for k,v in remainders.items() if k not in fac_terms} + + if any(u not in factored_terms for u in fac_terms) or any(factored_terms[u]%fac_terms[u]!=0 for u in fac_terms) or not \ + all_same(mul:=[factored_terms[u]//fac_terms[u] for u in fac_terms]): + continue + + remainders = new_remainders + factored.append(fac*mul[0]) + if not factored: return self + start = functools.reduce(operator.add, factored) + return sum([k.factor(*factors)*v for k,v in remainders.items()], start=start) + + if self.op not in GroupOp.ALU|{Ops.VECTORIZE}: return self + return self.replace(src=tuple(s.factor(*factors) for s in self.src)) def pop_const(self, op=Ops.ADD) -> tuple[UOp, ConstType]: return (self.src[0], self.src[1].arg) if self.op is op and self.src[1].op is Ops.CONST else (self, identity_element(op, self.dtype)) @staticmethod diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 7a335339cf..7da4767712 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -274,10 +274,16 @@ gep_pushing = PatternMatcher([ (UPat(Ops.WMMA, name="wmma").f(Ops.GEP, name="gep"), gep_through_wmma), ]) +def chain_insert(chain, b, op): + if chain.op is not op or b.order_add > chain.src[1].order_add: return chain.alu(op, b) + return chain_insert(chain.src[0], b, op).alu(op, chain.src[1]) + commutative = PatternMatcher([ # ** COMMUTATIVE flipping (only for index) ** # NOTE: this can break merging vector math by only flipping some of them - (UPat(GroupOp.Commutative, dtype=dtypes.index, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None), + 
(UPat(GroupOp.Commutative-{Ops.ADD}, dtype=dtypes.index, name='x'), lambda x: + x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None), + (UPat(Ops.ADD, dtype=dtypes.index, name="x"), lambda x: functools.reduce(operator.add, sorted(x.split_uop(Ops.ADD), key=lambda u: u.order_add))) ]) symbolic = symbolic_simple+commutative+PatternMatcher([ @@ -373,7 +379,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ ])+gep_pushing symbolic_flat = symbolic+PatternMatcher([ - # ** combine terms (opinionated) ** + # ** combine terms (opinionated), can make it harder to substitute valids ** (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue ((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c), @@ -405,10 +411,13 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None: # don't simplify any other gates, can lead to OOB, we substitute them back later uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, arg=u) for u in uop.toposort() if u.op is Ops.INDEX})) + all_candidates = [] # simplify uop given that valid is True - for expr,v in bounds.items(): + for i, (expr,v) in enumerate(bounds.items()): v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1]) expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop + # if the expr is an add we try to factorize so it's more likely to substitute + if expr.op is Ops.ADD: uop = uop.factor(expr) # some expr has lower bound > upper bound -> valid is an empty set and we return None if v0 > v1: return None # whole node became a const @@ -421,7 +430,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None: # if the constraint is a simplex: X0 + X1 + ... 
> 0, we can check if all Xi > 0 simplify into the same output candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)]) # try checking the whole clause - if expr in uop.toposort(): candidates.append([(expr, UOp.variable("fake", v0, v1, expr.dtype))]) + if expr in uop.toposort(): + candidates.append([tup:=(expr, UOp.variable(f"fake{i}", v0, v1, expr.dtype))]) + all_candidates.append(tup) for candidate in candidates: # if every branch in candidate gives the same simplified uop, we can rewrite the uop @@ -431,6 +442,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None: if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1])) elif all_same(newuops): uop = newuops[0] + uop = uop.factor(*(e[0] for e in all_candidates)) + uop = uop.substitute(sub_dict:=dict(all_candidates)).simplify().substitute({newX:X for X,newX in sub_dict.items()}).simplify() + # put the loads back in uop = uop.substitute({v:k for k,v in load_subs.items()}) return uop