mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Merge branch 'master' into late_add_load
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import itertools
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start, ImageDType
|
||||
from tinygrad.uop.symbolic import symbolic_flat
|
||||
from tinygrad.helpers import partition, dedup
|
||||
@@ -18,9 +19,8 @@ pm_flatten_range = PatternMatcher([
|
||||
def count_divmod(x:UOp): return len([u for u in x.toposort() if u.op in {Ops.IDIV, Ops.MOD}])
|
||||
def simplify_merge_adjacent(u:UOp) -> UOp|None:
|
||||
reduce_ranges = [x.ranges for x in u.backward_slice_with_self if x.op is Ops.REDUCE]
|
||||
i = 0
|
||||
while i < len(u.ended_ranges)-1:
|
||||
r0, r1 = u.ended_ranges[i], u.ended_ranges[i+1]
|
||||
# on END we only want to merge adjacent ranges, on REDUCE we want to try all combinations
|
||||
for r0, r1 in (zip(u.ended_ranges, u.ended_ranges[1:]) if u.op is Ops.END else itertools.permutations(u.ended_ranges, 2)):
|
||||
# check same type
|
||||
if r0.arg[-1] == r1.arg[-1]:
|
||||
# check if the ranges to merge are in the same reduces
|
||||
@@ -35,7 +35,6 @@ def simplify_merge_adjacent(u:UOp) -> UOp|None:
|
||||
if count_divmod(nidx) <= count_divmod(u):
|
||||
u = nidx
|
||||
continue
|
||||
i += 1
|
||||
return u
|
||||
|
||||
pm_simplify_ranges = PatternMatcher([
|
||||
|
||||
Reference in New Issue
Block a user