mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
linearizer: fix up edge case bugs in UNROLL opt (#3362)
Fully UNROLLing the first_reduce should not change the number of local_dims. Fully UNROLLing a GROUP dim should reduce the number of group_for_reduces by one. Also changed group_for_reduces to be a count as the axis number isn't used anywhere (they are always the first reduce dims).
This commit is contained in:
@@ -383,7 +383,7 @@ class TestHandCodedOpts(unittest.TestCase):
|
||||
k = Linearizer(s.ast)
|
||||
k.hand_coded_optimizations()
|
||||
|
||||
assert len(k.group_for_reduce) == 1
|
||||
assert k.group_for_reduces == 1
|
||||
assert k.local_dims == 1
|
||||
assert k.upcasted == 1
|
||||
|
||||
@@ -612,9 +612,14 @@ class TestLinearizerOpts(unittest.TestCase):
|
||||
opts_shapes = [
|
||||
([Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("red",32)]),
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",2),("red",16)]),
|
||||
# TODO: fix these broken transformations
|
||||
# ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
# ([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]),
|
||||
# check to ensure local_dims are stable for full UNROLL of first_reduce
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
([Opt(OptOps.UNROLL, 0, 0),Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
# check behavior for full UNROLL on an existing GROUP
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",16),("magenta",2)]),
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
([Opt(OptOps.GROUP, 0, 0),Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]),
|
||||
]
|
||||
helper_linearizer_opt(r, [x[0] for x in opts_shapes], color_sizes=[x[1] for x in opts_shapes])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user