Tensor.repeat cleanup (#7735)

flatten instead of double for loop comprehension
This commit is contained in:
chenyu
2024-11-16 10:43:45 -05:00
committed by GitHub
parent f1efd84c92
commit e777211a00

View File

@@ -1297,11 +1297,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
```
"""
repeats = argfix(repeats, *args)
base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
new_shape = [x for b in base_shape for x in [1, b]]
expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
base_shape = _pad_left(self.shape, repeats)[0]
unsqueezed_shape = flatten([[1, s] for s in base_shape])
expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)])
final_shape = [r*s for r,s in zip(repeats, base_shape)]
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):