mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
fix efficientnet slowness on rangeify (#12332)
This commit is contained in:
@@ -108,6 +108,13 @@ class TestRangeify(unittest.TestCase):
|
||||
w2 = Tensor.empty(12, 8, 3, 3)
|
||||
x.conv2d(w1).conv2d(w2).realize()
|
||||
|
||||
def test_xception_conv2d(self):
|
||||
# NOTE: this fusion is bad, it's recomputing the inner many times
|
||||
x = Tensor.empty(1, 4, 32, 32)
|
||||
w1 = Tensor.empty(8, 4, 1, 1)
|
||||
w2 = Tensor.empty(8, 1, 3, 3)
|
||||
x.conv2d(w1).conv2d(w2, groups=8).realize()
|
||||
|
||||
def test_conv_maxpool_contig(self): self.test_conv_maxpool(True)
|
||||
def test_conv_maxpool(self, contig=False):
|
||||
GlobalCounters.reset()
|
||||
|
||||
@@ -170,7 +170,10 @@ def map_expand(r:UOp, idx:UOp):
|
||||
else:
|
||||
ending_ranges.extend(axis_to_range)
|
||||
new_rngs.append(a.const_like(0))
|
||||
ending_ranges = [x.arg for x in ending_ranges if x not in non_ending_ranges]
|
||||
# if RANGEIFY >= 2, we are aggressive about not ending ranges
|
||||
if RANGEIFY >= 2: ending_ranges = [x.arg for x in ending_ranges if x not in non_ending_ranges]
|
||||
# if RANGEIFY=1, if it's ending at all we end it
|
||||
else: ending_ranges = [x.arg for x in ending_ranges]
|
||||
if idx.arg is not None: ending_ranges.append(idx.arg)
|
||||
return r.src[0].index(*new_rngs, arg=min(ending_ranges) if ending_ranges else None)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user