diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index da6202ae0a..20c1cceb61 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1297,11 +1297,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method ``` """ repeats = argfix(repeats, *args) - base_shape = (1,) * (len(repeats) - self.ndim) + self.shape - new_shape = [x for b in base_shape for x in [1, b]] - expand_shape = [x for rs in zip(repeats, base_shape) for x in rs] + base_shape = _pad_left(self.shape, repeats)[0] + unsqueezed_shape = flatten([[1, s] for s in base_shape]) + expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)]) final_shape = [r*s for r,s in zip(repeats, base_shape)] - return self.reshape(new_shape).expand(expand_shape).reshape(final_shape) + return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape) def _resolve_dim(self, dim:int, *, outer:bool=False) -> int: if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):