remove duplicated move_early logic in UOp.r [pr] (#10993)

This commit is contained in:
chenyu
2025-06-26 18:33:54 -04:00
committed by GitHub
parent 579194f523
commit 4572e65f0f

View File

@@ -257,9 +257,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
if len(axis) == 0: return self
# move any non reduce axis before the first reduce axis
move_early = [i for i in range(axis[0], len(self.shape)) if i not in axis and resolve(self.shape[i] != 1)]
move_early, rest = partition(range(axis[0], len(self.shape)), lambda i: i not in axis and resolve(self.shape[i] != 1))
if move_early:
permute = tuple(range(axis[0])) + tuple(move_early) + tuple([i for i in range(axis[0], len(self.shape)) if i not in move_early])
permute = tuple(range(axis[0])) + tuple(move_early) + tuple(rest)
ret = self.permute(permute)
new_axis = tuple([x for x in range(axis[0]+len(move_early), len(self.shape)) if resolve(ret.shape[x] != 1)])
assert len(axis) == len(new_axis)