mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
This reverts commit 2113e1eb63.
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -120,12 +120,8 @@ class LazyBuffer:
|
||||
assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
|
||||
axis = tuple(x for x in axis if self.shape[x] != 1)
|
||||
if len(axis) == 0: return self
|
||||
# move all reduces to the end
|
||||
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
|
||||
permute_order = tuple(x for x in range(len(self.shape)) if x not in axis) + axis
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape).permute(permute_order),
|
||||
self.dtype, op, tuple(range(len(self.shape)-len(axis), len(self.shape))),
|
||||
(self.permute(permute_order),)).reshape(new_shape)
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,))
|
||||
|
||||
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
|
||||
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
|
||||
|
||||
Reference in New Issue
Block a user