perf: use enumerate where possible (#1692)

Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
Roelof van Dijk
2023-08-30 19:41:51 +02:00
committed by GitHub
parent a8aa13dc91
commit 62536d6000
6 changed files with 9 additions and 9 deletions

View File

@@ -408,8 +408,8 @@ class Tensor:
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False):
axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if axis.__class__ is int else list(axis)) # type: ignore
axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
shape = [self.shape[i] for i in range(len(self.shape)) if i not in axis_]
ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else self.shape[i] for i in range(len(self.shape))]))
shape = [s for i,s in enumerate(self.shape) if i not in axis_]
ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
return ret if keepdim else ret.reshape(shape=shape)
def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)