clean up Tensor._reduce (#7382)

use make_tuple and self.ndim
This commit is contained in:
chenyu
2024-10-29 17:23:57 -04:00
committed by GitHub
parent 4ed2c40d48
commit f6abde95fa
2 changed files with 4 additions and 5 deletions

View File

@@ -32,7 +32,7 @@ def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 e
def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
def ansilen(s:str): return len(ansistrip(s))
def make_tuple(x:Union[int, Tuple[int, ...]], cnt) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
def make_tuple(x:Union[int, Sequence[int]], cnt:int) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
def fully_flatten(l):
if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):

View File

@@ -1447,10 +1447,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
if self.ndim == 0:
if axis is not None and any(a not in [-1, 0] for a in fully_flatten([axis])): raise IndexError(f"{axis=} out of range of [-1, 0]")
axis = ()
axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
axis_ = tuple(self._resolve_dim(x) for x in axis_)
ret = fxn.apply(self, axis=axis_)
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))
axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
ret = fxn.apply(self, axis=axis)
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
"""