linearizer: fix up edge case bugs in UNROLL opt (#3362)

Fully UNROLLing the first_reduce should not change the number of
local_dims.

Fully UNROLLing a GROUP dim should reduce the number of
group_for_reduces by one.

Also changed group_for_reduces to be a count as the axis number
isn't used anywhere (they are always the first reduce dims).
This commit is contained in:
Francis Lam
2024-02-10 02:49:25 -08:00
committed by GitHub
parent dc82ef6660
commit ddb22a60c8
3 changed files with 40 additions and 33 deletions

View File

@@ -383,7 +383,7 @@ class TestHandCodedOpts(unittest.TestCase):
k = Linearizer(s.ast)
k.hand_coded_optimizations()
assert len(k.group_for_reduce) == 1
assert k.group_for_reduces == 1
assert k.local_dims == 1
assert k.upcasted == 1
@@ -612,9 +612,14 @@ class TestLinearizerOpts(unittest.TestCase):
opts_shapes = [
([Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("red",32)]),
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",2),("red",16)]),
# TODO: fix these broken transformations
# ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
# ([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]),
# check to ensure local_dims are stable for full UNROLL of first_reduce
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
([Opt(OptOps.UNROLL, 0, 0),Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
# check behavior for full UNROLL on an existing GROUP
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",16),("magenta",2)]),
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
([Opt(OptOps.GROUP, 0, 0),Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]),
]
helper_linearizer_opt(r, [x[0] for x in opts_shapes], color_sizes=[x[1] for x in opts_shapes])