UOp.factor and add chain sorting (#12413)

* add ordering

* fix some tests

* fix more tests

* shorten comment

* update test

* add rule and test

* add rule and test

* remove check

* use fold_divmod_congruence instead of simplify

* adjust tests

* shorten line

* new algo

* add test

* add function to un-nest the div

* add UOp.factor

* test UOp.factor

* uop_given_valid tries to factor simplex expression

* shorten line

* symbolic_flat is back

* change that back

* fix those new tests

* new rule for ordering

* factor multiple factors

* no symbolic_flat

* symbolic_flat to there

* move that back

* fix imports

* merge correctly

* linter happy

* add rule

* add a test

* cleanup

* revert that for now

* UOp.factor returns self instead of None

* try all_candidates

* remove or_else

* post index symbolic

* add test

* maket this closer to the original

* increase mac hlb_cifar min step time

* add some ordering tests

* cleanup

* increase pytest timeout time

* check dtype
This commit is contained in:
Sieds Lykles
2025-10-04 06:05:38 +02:00
committed by GitHub
parent 394dc24110
commit e74be4a140
7 changed files with 142 additions and 61 deletions

View File

@@ -109,7 +109,7 @@ jobs:
- name: Train MNIST
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py | tee beautiful_mnist.txt
- name: Run 10 CIFAR training steps
run: BENCHMARK_LOG=cifar_10steps JIT=1 ASSERT_MIN_STEP_TIME=320 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar.txt
run: BENCHMARK_LOG=cifar_10steps JIT=1 ASSERT_MIN_STEP_TIME=330 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar.txt
- name: Run 10 CIFAR training steps w HALF
run: BENCHMARK_LOG=cifar_10steps_half JIT=2 ASSERT_MIN_STEP_TIME=385 STEPS=10 DEFAULT_FLOAT=HALF python3.11 examples/hlb_cifar10.py | tee train_cifar_half.txt
#- name: Run 10 CIFAR training steps w BF16

View File

@@ -1,6 +1,6 @@
[pytest]
norecursedirs = extra
timeout = 180
timeout = 240
timeout_method = thread
timeout_func_only = true
testpaths = test

View File

@@ -28,22 +28,6 @@ class TestRewriteMap(unittest.TestCase):
self.assertIs(sub_map[a+b], e)
self.assertIs(sub_map[(a+b)*c], f)
def test_multistage_substitute(self):
a = UOp.variable('a', 0, 10)
b = UOp.variable('b', 0, 10)
c = UOp.variable('c', 0, 10)
d = UOp.variable('d', 0, 10)
sub1 = {a+b:c}
start = (a+b)*c
# stage 1: (a+b)*c -> c*c
sub_map1 = graph_rewrite_map(start, _substitute, sub1, bottom_up=True)
self.assertIs(sub_map1[(a+b)*c], c*c)
# stage 2: c*c -> d
sub2 = {c*c:d}
sub_map2 = graph_rewrite_map(sub_map1[start], _substitute, sub2, input_map=sub_map1, bottom_up=True)
# (a+b)*c -> c*c -> d
self.assertIs(sub_map2[(a+b)*c], d)
def test_add_zero(self):
# Build a small graph: add(0, add(const=0, const=5))
zero_node = UOp.const(dtypes.index, 0)
@@ -144,11 +128,11 @@ class TestRewriteMap(unittest.TestCase):
yz_sum_zero = yz_sum + zero_node -> rewrites to yz_sum
yz_neg = -yz_sum_zero -> -(y+z)
yz_dneg = -yz_neg -> y+z (double neg gone)
x_plus_yz = x_var + yz_dneg -> x + (y+z)
double_neg_x = -(-x_plus_yz) -> x + (y+z)
final_expr = double_neg_x * one_node -> x + (y+z)
x_plus_yz = x_var + yz_dneg -> (x+y)+z (add nodes get sorted)
double_neg_x = -(-x_plus_yz) -> (x+y)+z
final_expr = double_neg_x * one_node -> (x+y)+z
We expect the final result to be (x + (y+z)).
We expect the final result to be ((x+y)+z).
Each original node should map to the final node that replaces it,
which might be structurally equivalent but not the same reference.
"""
@@ -163,9 +147,9 @@ class TestRewriteMap(unittest.TestCase):
yz_sum_zero = yz_sum + zero_node # (y + z) + 0
yz_neg = -yz_sum_zero # -(y+z)
yz_dneg = -yz_neg # -(-(y+z)) -> (y+z)
x_plus_yz = x_var + yz_dneg # x + (y+z)
double_neg_x = -(-x_plus_yz) # neg(neg(x+(y+z))) -> x+(y+z)
final_expr = double_neg_x * one_node # (x+(y+z)) * 1 -> x+(y+z)
x_plus_yz = x_var + yz_dneg # x + (y+z) -> (x+y)+z
double_neg_x = -(-x_plus_yz) # neg(neg(x+(y+z))) -> (x+y)+z
final_expr = double_neg_x * one_node # ((x+y)+z) * 1 -> (x+y)+z
node_map = graph_rewrite_map(final_expr, symbolic)
@@ -182,14 +166,15 @@ class TestRewriteMap(unittest.TestCase):
# -(-(y+z)) => (y+z)
self.assertEqual(node_map[yz_dneg], yz_sum)
# x + (y+z) => might get recreated if yz_dneg was changed, so compare to x + yz_sum
self.assertEqual(node_map[x_plus_yz], x_var + yz_sum)
# x + (y+z) => (x+y)+z
expected_xyz = (x_var + y_var) + z_var
self.assertEqual(node_map[x_plus_yz], expected_xyz)
# -(-(x+(y+z))) => x + (y+z)
self.assertEqual(node_map[double_neg_x], x_var + yz_sum)
# -(-(x+(y+z))) => (x+y)+z
self.assertEqual(node_map[double_neg_x], expected_xyz)
# (x+(y+z)) * 1 => x+(y+z)
self.assertEqual(node_map[final_expr], x_var + yz_sum)
# ((x+y)+z) * 1 => (x+y)+z
self.assertEqual(node_map[final_expr], expected_xyz)
# Unchanged atomic nodes map to themselves
self.assertEqual(node_map[x_var], x_var)

View File

@@ -60,7 +60,7 @@ class TestValidIdxSimplification(unittest.TestCase):
load = get_gated_load_uop(gate, idx)
self.check(load,
"0",
"(((lidx0+(gidx0*4))<19)!=True)")
"((((gidx0*4)+lidx0)<19)!=True)")
def test_simplify_within_valid1(self):
ridx0 = Range(0, 4)
@@ -184,7 +184,6 @@ class TestValidIdxSimplification(unittest.TestCase):
print("The expressions are not equivalent.")
print(s.model())
@unittest.expectedFailure # TODO: improve uop_given_valid
def test_valid_becomes_const2(self):
ridx0 = Range(0, 4)
ridx1 = Range(1, 4)
@@ -304,7 +303,7 @@ class TestImageSimplification(unittest.TestCase):
idx = ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
load = get_load_image_uop(shape, valid, idx)
self.check(load, None, "((((idx1*48)+(r2*6))+r0)+-6)", "(((idx2*2)+r1)+-1)")
self.check(load, None, "((((idx1*48)+r0)+(r2*6))+-6)", "(((idx2*2)+r1)+-1)")
def test_openpilot_conv2(self):
# conv in test/external/external_test_valid_remove.py
@@ -325,7 +324,7 @@ class TestImageSimplification(unittest.TestCase):
idx = ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
load = get_load_image_uop(shape, valid, idx)
self.check(load, None, "((((idx1*24)+(r2*3))+r0)+-3)", "(((idx2*2)+r1)+-1)")
self.check(load, None, "((((idx1*24)+r0)+(r2*3))+-3)", "(((idx2*2)+r1)+-1)")
def test_openpilot_conv3(self):
# in openpilot 0.9.7
@@ -347,8 +346,8 @@ class TestImageSimplification(unittest.TestCase):
self.check(load,
"((((idx2*2)+r0)<11)&((((idx1*8)+r1)<3)!=True))",
"(((idx0+((idx1*512)+(r1*64)))+832)%1024)",
"((((idx2*2)+r0)+(((idx1+((r1+5)//8))+1)//2))+-4)")
"(((idx0+(idx1*512))+(r1*64))+-192)",
"((((idx2*2)+(((idx1+((r1+5)//8))+1)//2))+r0)+-4)")
def test_simplify1(self):
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
@@ -386,16 +385,16 @@ class TestImageSimplification(unittest.TestCase):
# TODO: can this be simplified further?
load = get_load_image_uop(shape, alu9, (((alu8+(alu2*8))%64),(alu2//8)))
self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+8)%64)", "((idx0%8)//2)")
self.check(load, "(idx0<256)", "((((idx0//32)+((idx0%8)*32))+8)%64)", "((idx0%8)//2)")
load = get_load_image_uop(shape, alu9, (((alu8+(alu3*8))%64),(alu3//8)))
self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+16)%64)", "((idx0%8)//2)")
self.check(load, "(idx0<256)", "((((idx0//32)+((idx0%8)*32))+16)%64)", "((idx0%8)//2)")
load = get_load_image_uop(shape, alu9, (((alu8+(alu4*8))%64),(alu4//8)))
self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+24)%64)", "((idx0%8)//2)")
self.check(load, "(idx0<256)", "((((idx0//32)+((idx0%8)*32))+24)%64)", "((idx0%8)//2)")
load = get_load_image_uop(shape, alu9, (((alu8+(alu5*8))%64),(alu5//8)))
self.check(load, "(idx0<256)", "((((idx0%8)*32)+(idx0//32))%64)", "((idx0%8)//2)")
self.check(load, "(idx0<256)", "(((idx0//32)+((idx0%8)*32))%64)", "((idx0%8)//2)")
def test_simplify5(self):
# openpilot 0.9.7, chunk replacement to simplify

View File

@@ -116,6 +116,39 @@ class TestSymbolic(unittest.TestCase):
self.assertEqual((a*b*3+a*b*b).divide_exact(a*b).simplify(), b+3)
self.assertEqual((((a*-2)+14)*b).divide_exact(((a*-2)+14)).simplify(), b)
def helper_test_factor(self, expr, *factors):
factored = expr.factor(*factors)
self.check_equal_z3(expr, factored)
for fac in factors: self.assertIn(fac, factored.toposort())
def test_uop_factor(self):
a = Variable("a", 0, 8)
b = Variable("b", 0, 8)
c = Variable("c", 0, 8)
self.helper_test_factor((1400*a+2800*b), (a+2*b))
self.helper_test_factor((1400*a+2800*b)%9000, (a+2*b))
self.helper_test_factor((a+2*b), (a+2*b))
self.helper_test_factor((a+c+2*b), (a+2*b))
self.helper_test_factor((1400*a+c+2800*b)%9000, (a+2*b))
self.helper_test_factor((1399*a+c+2800*b)%9000+1400*a+2800*b, (a+2*b))
self.helper_test_factor((1400*a+c+2800*b)%9000+1400*a+2800*b, (a+2*b))
# self.assertIsNone((a+c+3*b).factor(a+2*b))
# self.assertIsNone((1399*a+c+2800*b).factor(a+2*b))
def test_uop_multiple_factors(self):
a = Variable("a", 0, 8)
b = Variable("b", 0, 8)
c = Variable("c", 0, 8)
d = Variable("d", 0, 8)
self.helper_test_factor((1400*a+2800*b+2*c+d), (a+2*b), (2*c+d))
self.helper_test_factor((100*a+200*b+5*c), (a+2*b), (5*c))
self.helper_test_factor((3*a+6*b+2*c+4*d), (a+2*b), (c+2*d))
self.helper_test_factor((7*a+14*b+3*c+6*d), (a+2*b), (3*c+6*d))
self.helper_test_factor((10*a+20*b+10*c+30*d), (a+2*b), (c+3*d))
self.helper_test_factor((10*c+(10*a+20*b)//3+30*d), (a+2*b), (c+3*d))
self.helper_test_factor((10*c+(10*a+20*b)//3+30*d), (a+2*b), (c+3*d))
# self.assertIsNone((7*a+14*b+3*c+6*d).factor((a+8*b), (2*c+6*d)))
def test_divide_exact_not(self):
a = Variable("a", 1, 8)
b = Variable("b", 1, 8)
@@ -130,13 +163,13 @@ class TestSymbolic(unittest.TestCase):
a = Variable("a", 0, 8)
b = Variable("b", 0, 8)
self.helper_test_variable(a*2+a*3, 0, 8*5, "(a*5)")
self.helper_test_variable(b+a*2+a*3, 0, 8*6, "(b+(a*5))")
self.helper_test_variable(b+a*2+a*3, 0, 8*6, "((a*5)+b)")
def test_factorize_no_mul(self):
a = Variable("a", 0, 8)
b = Variable("b", 0, 8)
self.helper_test_variable(a+a*3, 0, 8*4, "(a*4)")
self.helper_test_variable((a+b)+a*3, 0, 8*5, "(b+(a*4))")
self.helper_test_variable((a+b)+a*3, 0, 8*5, "((a*4)+b)")
self.helper_test_variable((a*3+b)+b*3, 0, 8*7, "((a*3)+(b*4))")
def test_neg(self):
@@ -159,8 +192,15 @@ class TestSymbolic(unittest.TestCase):
b = Variable("b", 0, 8)
self.helper_test_variable(a+a, 0, 16, "(a*2)")
self.helper_test_variable((a+b)+b, 0, 24, "(a+(b*2))")
self.helper_test_variable((a*3+b)+a, 0, 40, "(b+(a*4))")
self.helper_test_variable((a+b)+a*3, 0, 40, "(b+(a*4))")
self.helper_test_variable((a*3+b)+a, 0, 40, "((a*4)+b)")
self.helper_test_variable((a+b)+a*3, 0, 40, "((a*4)+b)")
def test_add_self_seperated(self):
a = Variable("a", 0, 8)
b = Variable("b", 0, 8)
c = Variable("c", 0, 8)
self.helper_test_variable((a+b)+c+a, 0, 32, "(((a*2)+b)+c)")
self.helper_test_variable((a*3+b*2)+c*2+a*5, 0, 96, "(((a*8)+(b*2))+(c*2))")
def test_sub_self(self):
a = Variable("a", 0, 8)
@@ -279,7 +319,7 @@ class TestSymbolic(unittest.TestCase):
def test_mod_congruence_multiple_vars(self):
self.helper_test_variable((9+9*Variable("x",0,3)+9*Variable("y",0,3))%10, 3, 9, "(((x*-1)+(y*-1))+9)")
self.helper_test_variable((7+9*Variable("x",0,2)+9*Variable("y",0,2)+Variable("z",0,2))%10, 3, 9,
("(((z+(x*-1))+(y*-1))+7)", "(((y*-1)+(z+(x*-1)))+7)"))
("(((z+(x*-1))+(y*-1))+7)", "(((y*-1)+(z+(x*-1)))+7)", "((((x*-1)+(y*-1))+z)+7)"))
self.helper_test_variable((10+12*Variable("x",0,2)+Variable("y", 0, 4)%3)%13, 8, 12, "(((x*-1)+(y%3))+10)")
def test_div_congruence(self):
@@ -455,7 +495,7 @@ class TestSymbolic(unittest.TestCase):
ridx1005 = UOp.variable("ridx1005", 0, 2)
ridx1006 = UOp.variable("ridx1006", 0, 2)
self.helper_test_variable((lidx1+((gidx1*18)+(ridx1005*18)+(lidx0*162))+(gidx0*2)+(ridx1006*2)+-40)//18, -2, 20,
"(((((lidx1+(((gidx1*18)+(ridx1005*18))+(lidx0*162)))+(gidx0*2))+(ridx1006*2))+-40)//18)")
"((((((((gidx0*2)+(gidx1*18))+(lidx0*162))+lidx1)+(ridx1005*18))+(ridx1006*2))+-40)//18)")
def test_add_div(self):
# careful about the lower bounds and upper bounds
@@ -498,7 +538,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((d1*a*b*d1)//(d1), -1000, 1000, "(a*(b*d1))", test_z3=False)
self.helper_test_variable((d1*a*d2*b*d1)//(d1*d2), -1000, 1000, "(a*(b*d1))", test_z3=False)
self.helper_test_variable((d1*a + b*d1)//(d1), -20, 20, "(a+b)", test_z3=False)
self.helper_test_variable((d1*a + b*d1 + c*d1)//(d1), -30, 30, "(c+(a+b))", test_z3=False)
self.helper_test_variable((d1*a + b*d1 + c*d1)//(d1), -30, 30, "((a+b)+c)", test_z3=False)
self.helper_test_variable((3*a*d1 + 9*b*d1)//(3*d1*d2), -40, 40, "(((a+(b*3))//(d2*-1))*-1)", test_z3=False)
self.helper_test_variable((3*a*d1 + 9*b*d1+3)//(3*d1*d2), -401, 399, "(((((a*d1)+((b*d1)*3))+1)//((d1*d2)*-1))*-1)", test_z3=False)
@@ -508,7 +548,7 @@ class TestSymbolic(unittest.TestCase):
d = Variable("d", 1, 10)
self.helper_test_variable((d*a+b)//d, 0, 20, "(a+(b//d))")
self.helper_test_variable((d*a*20+b)//(5*d), 0, 42, "((a*4)+(b//(d*5)))")
self.helper_test_variable((d*a*20+b*d*5+10)//(5*d), 0, 52, "((b+(a*4))+(2//d))")
self.helper_test_variable((d*a*20+b*d*5+10)//(5*d), 0, 52, "(((a*4)+b)+(2//d))")
def test_mod_gcd_factor_neg(self):
self.helper_test_variable((Variable("a", 0, 10)*-4+4)%8, -4, 4, "((((a*-1)+1)%2)*4)")
@@ -561,9 +601,7 @@ class TestSymbolic(unittest.TestCase):
lidx2 = Variable("lidx2", 0, 3)
alu0 = gidx2*640+gidx1*160+(gidx0//5)*2+lidx0*320+lidx1*10
self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192,
("((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(((gidx2*32)+(gidx1*8))+(lidx0*16)))",
"(((lidx1+((lidx2+(gidx0//5))//5))//2)+((gidx2*32)+((gidx1*8)+(lidx0*16))))",
"((((gidx1*8)+(gidx2*32))+(lidx0*16))+((lidx1+((lidx2+(gidx0//5))//5))//2))"))
("((((gidx1*8)+(gidx2*32))+(lidx0*16))+((lidx1+((lidx2+(gidx0//5))//5))//2))",))
def test_sum_div_complex2(self):
gidx0 = Variable("gidx0", 0, 7)
@@ -641,8 +679,21 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((gidx//4)*4+gidx%4, 0, 124, "gidx")
self.helper_test_variable(lidx+gidx%4+(gidx//4)*4, 0, 248, "(gidx+lidx)")
self.helper_test_variable(lidx+(gidx//4)*4+gidx%4, 0, 248, "(gidx+lidx)")
self.helper_test_variable(lidx+(gidx//4)*8+2*(gidx%4), 0, 372, "(lidx+(gidx*2))")
self.helper_test_variable(lidx+2*(gidx%4)+(gidx//4)*8, 0, 372, "(lidx+(gidx*2))")
self.helper_test_variable(lidx+(gidx//4)*8+2*(gidx%4), 0, 372, "((gidx*2)+lidx)")
self.helper_test_variable(lidx+2*(gidx%4)+(gidx//4)*8, 0, 372, "((gidx*2)+lidx)")
def test_div_mod_recombine_seperated(self):
gidx = Variable("gidx", 0, 124)
lidx = Variable("lidx", 0, 124)
a = Variable("a", 0, 3)
b = Variable("b", 0, 3)
c = Variable("c", 0, 3)
self.helper_test_variable(gidx%4+a+b+c+(gidx//4)*4, 0, 133, "(((a+b)+c)+gidx)")
self.helper_test_variable((gidx//4)*4+a+b*10+gidx%4, 0, 157, "((a+(b*10))+gidx)")
self.helper_test_variable(lidx+gidx%4+a+b+c//2+(gidx//4)*4, 0, 255, "((((a+b)+gidx)+lidx)+(c//2))")
self.helper_test_variable(lidx+(gidx//4)*8+b+c+a*8+2*(gidx%4), 0, 402, "(((((a*8)+b)+c)+(gidx*2))+lidx)")
# TODO: need better sorting for this one
# self.helper_test_variable(lidx+(gidx//4)*4+a*3+b*3+(c*10)%3+gidx%4, , , "")
def test_div_mod_recombine_folded_mod(self):
a = Variable("a", 0, 2)
@@ -1019,6 +1070,7 @@ class TestSymbolicRealWorld(unittest.TestCase):
("((((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352))+lidx3)+2207744)",
'((lidx3+((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352)))+2207744)',
'((lidx3+((lidx4*100352)+((gidx2*8)+((gidx1*784)+((gidx0*3211264)+((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49)))))))+2207744)',
'((((((((gidx0*3211264)+(gidx1*784))+(gidx2*8))+lidx3)+(lidx4*100352))+(((lidx5+1)//16)*802816))+(((lidx5+1)%16)*49))+2207744)',
))
class TestBounds(unittest.TestCase):

View File

@@ -8,7 +8,7 @@ from tinygrad.uop.mathtraits import MathTrait
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, RANGEIFY, VIZ, SPEC
from tinygrad.helpers import strip_parens
from tinygrad.helpers import strip_parens, make_tuple
if TYPE_CHECKING:
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer, MultiBuffer
@@ -150,6 +150,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def tuplize(self:UOp) -> tuple:
return (self.op.value, self.arg, self.dtype,)+tuple([x.tuplize for x in self.src])
@functools.cached_property
def order_add(self:UOp) -> tuple:
if self.op is Ops.MUL and self.src[1].op in (Ops.CONST, Ops.VCONST): return (self.src[0].tuplize, make_tuple(self.src[1].arg, 1))
return (self.tuplize, (0,))
@property
def ptrdtype(self) -> PtrDType:
if not isinstance(self.dtype, PtrDType): raise RuntimeError("ptrdtype called on UOp without PtrDType")
@@ -241,9 +246,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def simplify(self, tracked=False):
# late import!
from tinygrad.uop.symbolic import symbolic
from tinygrad.uop.symbolic import symbolic_flat
with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value):
return graph_rewrite(self, symbolic, name="simplify")
return graph_rewrite(self, symbolic_flat, name="simplify")
def ssimplify(self) -> UOp|ConstType: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
def _eval(self, dtype, expected_type:Type[T]) -> T:
assert self.dtype in dtype, f"eval with wrong dtype {self}"
@@ -568,6 +573,32 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
return None # generic None if we aren't sure
def factor(self, *factors: UOp) -> UOp:
# factor out expr from self if possible, might return self
# (1400*a + 2800*b + c).factor(a+2*b) -> 1400*(a+2*b) + c
if self.dtype in dtypes.floats: return self
if self.op is Ops.ADD:
factored = []
# dict of {term: const_factor}, i.e. {a: 1, b: 2}
remainders = dict([(u.divides(f:=u.const_factor()).simplify(),f) for u in self.split_uop(Ops.ADD)])
for fac in factors:
if fac.dtype not in (dtypes.index,)+dtypes.ints: continue
fac_terms = dict((u.divides(f:=u.const_factor()).simplify(),f) for u in fac.split_uop(Ops.ADD))
factored_terms = {k:v for k,v in remainders.items() if k in fac_terms}
new_remainders = {k:v for k,v in remainders.items() if k not in fac_terms}
if any(u not in factored_terms for u in fac_terms) or any(factored_terms[u]%fac_terms[u]!=0 for u in fac_terms) or not \
all_same(mul:=[factored_terms[u]//fac_terms[u] for u in fac_terms]):
continue
remainders = new_remainders
factored.append(fac*mul[0])
if not factored: return self
start = functools.reduce(operator.add, factored)
return sum([k.factor(*factors)*v for k,v in remainders.items()], start=start)
if self.op not in GroupOp.ALU|{Ops.VECTORIZE}: return self
return self.replace(src=tuple(s.factor(*factors) for s in self.src))
def pop_const(self, op=Ops.ADD) -> tuple[UOp, ConstType]:
return (self.src[0], self.src[1].arg) if self.op is op and self.src[1].op is Ops.CONST else (self, identity_element(op, self.dtype))
@staticmethod

View File

@@ -274,10 +274,16 @@ gep_pushing = PatternMatcher([
(UPat(Ops.WMMA, name="wmma").f(Ops.GEP, name="gep"), gep_through_wmma),
])
def chain_insert(chain, b, op):
if chain.op is not op or b.order_add > chain.src[1].order_add: return chain.alu(op, b)
return chain_insert(chain.src[0], b, op).alu(op, chain.src[1])
commutative = PatternMatcher([
# ** COMMUTATIVE flipping (only for index) **
# NOTE: this can break merging vector math by only flipping some of them
(UPat(GroupOp.Commutative, dtype=dtypes.index, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
(UPat(GroupOp.Commutative-{Ops.ADD}, dtype=dtypes.index, name='x'), lambda x:
x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
(UPat(Ops.ADD, dtype=dtypes.index, name="x"), lambda x: functools.reduce(operator.add, sorted(x.split_uop(Ops.ADD), key=lambda u: u.order_add)))
])
symbolic = symbolic_simple+commutative+PatternMatcher([
@@ -373,7 +379,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
])+gep_pushing
symbolic_flat = symbolic+PatternMatcher([
# ** combine terms (opinionated) **
# ** combine terms (opinionated), can make it harder to substitute valids **
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
@@ -405,10 +411,13 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
# don't simplify any other gates, can lead to OOB, we substitute them back later
uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}))
all_candidates = []
# simplify uop given that valid is True
for expr,v in bounds.items():
for i, (expr,v) in enumerate(bounds.items()):
v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop
# if the expr is an add we try and factorize so its more likely to substitute
if expr.op is Ops.ADD: uop = uop.factor(expr)
# some expr has lower bound > upper bound -> valid is an empty set and we return None
if v0 > v1: return None
# whole node became a const
@@ -421,7 +430,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)])
# try checking the whole clause
if expr in uop.toposort(): candidates.append([(expr, UOp.variable("fake", v0, v1, expr.dtype))])
if expr in uop.toposort():
candidates.append([tup:=(expr, UOp.variable(f"fake{i}", v0, v1, expr.dtype))])
all_candidates.append(tup)
for candidate in candidates:
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
@@ -431,6 +442,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
elif all_same(newuops): uop = newuops[0]
uop = uop.factor(*(e[0] for e in all_candidates))
uop = uop.substitute(sub_dict:=dict(all_candidates)).simplify().substitute({newX:X for X,newX in sub_dict.items()}).simplify()
# put the loads back in
uop = uop.substitute({v:k for k,v in load_subs.items()})
return uop