From 5ce8a1d2f2ecf1d2e816007f1d01c8ad9f56f72c Mon Sep 17 00:00:00 2001 From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com> Date: Wed, 29 Oct 2025 05:04:54 +0100 Subject: [PATCH] Merge adjacent try all permutations for reduce (#12972) --- tinygrad/codegen/simplify.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py index 35df3d304e..5562a32603 100644 --- a/tinygrad/codegen/simplify.py +++ b/tinygrad/codegen/simplify.py @@ -1,3 +1,4 @@ +import itertools from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start, ImageDType from tinygrad.uop.symbolic import symbolic_flat from tinygrad.helpers import partition, dedup @@ -18,9 +19,8 @@ pm_flatten_range = PatternMatcher([ def count_divmod(x:UOp): return len([u for u in x.toposort() if u.op in {Ops.IDIV, Ops.MOD}]) def simplify_merge_adjacent(u:UOp) -> UOp|None: reduce_ranges = [x.ranges for x in u.backward_slice_with_self if x.op is Ops.REDUCE] - i = 0 - while i < len(u.ended_ranges)-1: - r0, r1 = u.ended_ranges[i], u.ended_ranges[i+1] + # on END we only want to merge adjacent ranges, on REDUCE we want to try all combinations + for r0, r1 in (zip(u.ended_ranges, u.ended_ranges[1:]) if u.op is Ops.END else itertools.permutations(u.ended_ranges, 2)): # check same type if r0.arg[-1] == r1.arg[-1]: # check if the ranges to merge are in the same reduces @@ -35,7 +35,6 @@ def simplify_merge_adjacent(u:UOp) -> UOp|None: if count_divmod(nidx) <= count_divmod(u): u = nidx continue - i += 1 return u pm_simplify_ranges = PatternMatcher([