Revert "move all reduces to the end in lazy (#3475)" (#3529)

This reverts commit 2113e1eb63.

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
George Hotz
2024-02-28 16:24:10 -08:00
committed by GitHub
parent 0c6846f9fc
commit 42eb8de0d4

View File

@@ -120,12 +120,8 @@ class LazyBuffer:
assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
axis = tuple(x for x in axis if self.shape[x] != 1)
if len(axis) == 0: return self
# move all reduces to the end
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
permute_order = tuple(x for x in range(len(self.shape)) if x not in axis) + axis
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape).permute(permute_order),
self.dtype, op, tuple(range(len(self.shape)-len(axis), len(self.shape))),
(self.permute(permute_order),)).reshape(new_shape)
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,))
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))