mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
remove duplicated move_early logic in UOp.r [pr] (#10993)
This commit is contained in:
@@ -257,9 +257,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
||||
if len(axis) == 0: return self
|
||||
# move any non reduce axis before the first reduce axis
|
||||
move_early = [i for i in range(axis[0], len(self.shape)) if i not in axis and resolve(self.shape[i] != 1)]
|
||||
move_early, rest = partition(range(axis[0], len(self.shape)), lambda i: i not in axis and resolve(self.shape[i] != 1))
|
||||
if move_early:
|
||||
permute = tuple(range(axis[0])) + tuple(move_early) + tuple([i for i in range(axis[0], len(self.shape)) if i not in move_early])
|
||||
permute = tuple(range(axis[0])) + tuple(move_early) + tuple(rest)
|
||||
ret = self.permute(permute)
|
||||
new_axis = tuple([x for x in range(axis[0]+len(move_early), len(self.shape)) if resolve(ret.shape[x] != 1)])
|
||||
assert len(axis) == len(new_axis)
|
||||
|
||||
Reference in New Issue
Block a user