fix efficientnet slowness on rangeify (#12332)

This commit is contained in:
George Hotz
2025-09-29 20:01:01 +10:00
committed by GitHub
parent 9d2f2b8e34
commit 3291e00df7
2 changed files with 11 additions and 1 deletions

View File

@@ -108,6 +108,13 @@ class TestRangeify(unittest.TestCase):
w2 = Tensor.empty(12, 8, 3, 3)
x.conv2d(w1).conv2d(w2).realize()
def test_xception_conv2d(self):
# NOTE: this fusion is bad, it's recomputing the inner many times
x = Tensor.empty(1, 4, 32, 32)
w1 = Tensor.empty(8, 4, 1, 1)
w2 = Tensor.empty(8, 1, 3, 3)
x.conv2d(w1).conv2d(w2, groups=8).realize()
def test_conv_maxpool_contig(self): self.test_conv_maxpool(True)
def test_conv_maxpool(self, contig=False):
GlobalCounters.reset()

View File

@@ -170,7 +170,10 @@ def map_expand(r:UOp, idx:UOp):
else:
ending_ranges.extend(axis_to_range)
new_rngs.append(a.const_like(0))
ending_ranges = [x.arg for x in ending_ranges if x not in non_ending_ranges]
# if RANGEIFY >= 2, we are aggressive about not ending ranges
if RANGEIFY >= 2: ending_ranges = [x.arg for x in ending_ranges if x not in non_ending_ranges]
# if RANGEIFY=1, if it's ending at all we end it
else: ending_ranges = [x.arg for x in ending_ranges]
if idx.arg is not None: ending_ranges.append(idx.arg)
return r.src[0].index(*new_rngs, arg=min(ending_ranges) if ending_ranges else None)