fix PADTO optimization (#2935)

the correct condition is that PADTO cannot be applied to reduce axis, not Reduce.MAX in ops.
even for Reduce.SUM it's possible that the reduce axis had a div before, and the padded 0 became inf then sum over it is incorrect.
This commit is contained in:
chenyu
2023-12-25 22:52:49 -05:00
committed by GitHub
parent dca5e4fe74
commit 820f2e054e
5 changed files with 38 additions and 17 deletions

View File

@@ -519,25 +519,37 @@ class TestLinearizerOpts(unittest.TestCase):
helper_linearizer_opt(a@b, [
[Opt(OptOps.PADTO, 0, 32)],
[Opt(OptOps.PADTO, 1, 32)],
[Opt(OptOps.PADTO, 2, 32)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32)],
# can optimize further post PADTO
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UNROLL, 0, 4)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2),],
])
def test_padto_max(self):
# pad uses invalid value 0, so max is not allowed
N = 17 * 17
a = -Tensor.ones(N, N)
helper_linearizer_opt(a.max(0), [
[Opt(OptOps.PADTO, 0, 32)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
])
helper_linearizer_opt(a.max(1), [
[Opt(OptOps.PADTO, 0, 32)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
])
# cannot pad a reduce axis
with self.assertRaises(AssertionError):
helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
with self.assertRaises(AssertionError):
helper_linearizer_opt(a.max(0), [[Opt(OptOps.PADTO, 1, 32)],])
def test_padto_where(self):
# pad uses invalid value 0, so kernel with max is not allowed
N = 17 * 17
a = (Tensor.rand(N, N).max(axis=0) > 1).where(1, 0)
with self.assertRaises(AssertionError):
helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
a = (Tensor.rand(N, N).max(axis=0, keepdim=True) > 1).where(1, 0)
helper_linearizer_opt(a.max(0), [
[Opt(OptOps.PADTO, 0, 32)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
])
if __name__ == '__main__':
unittest.main()