diff --git a/test/test_rangeify.py b/test/test_rangeify.py index 302d0ebd43..05fd2bde69 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -108,6 +108,13 @@ class TestRangeify(unittest.TestCase): w2 = Tensor.empty(12, 8, 3, 3) x.conv2d(w1).conv2d(w2).realize() + def test_xception_conv2d(self): + # NOTE: this fusion is bad, it's recomputing the inner many times + x = Tensor.empty(1, 4, 32, 32) + w1 = Tensor.empty(8, 4, 1, 1) + w2 = Tensor.empty(8, 1, 3, 3) + x.conv2d(w1).conv2d(w2, groups=8).realize() + def test_conv_maxpool_contig(self): self.test_conv_maxpool(True) def test_conv_maxpool(self, contig=False): GlobalCounters.reset() diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 97cc691066..d8396eae93 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -170,7 +170,10 @@ def map_expand(r:UOp, idx:UOp): else: ending_ranges.extend(axis_to_range) new_rngs.append(a.const_like(0)) - ending_ranges = [x.arg for x in ending_ranges if x not in non_ending_ranges] + # if RANGEIFY >= 2, we are aggressive about not ending ranges + if RANGEIFY >= 2: ending_ranges = [x.arg for x in ending_ranges if x not in non_ending_ranges] + # if RANGEIFY=1, if it's ending at all we end it + else: ending_ranges = [x.arg for x in ending_ranges] if idx.arg is not None: ending_ranges.append(idx.arg) return r.src[0].index(*new_rngs, arg=min(ending_ranges) if ending_ranges else None)