From e777211a005e000ffb48231f87f6d46b6cbc2f58 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 16 Nov 2024 10:43:45 -0500 Subject: [PATCH] Tensor.repeat cleanup (#7735) flatten instead of double for loop comprehension --- tinygrad/tensor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index da6202ae0a..20c1cceb61 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1297,11 +1297,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method ``` """ repeats = argfix(repeats, *args) - base_shape = (1,) * (len(repeats) - self.ndim) + self.shape - new_shape = [x for b in base_shape for x in [1, b]] - expand_shape = [x for rs in zip(repeats, base_shape) for x in rs] + base_shape = _pad_left(self.shape, repeats)[0] + unsqueezed_shape = flatten([[1, s] for s in base_shape]) + expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)]) final_shape = [r*s for r,s in zip(repeats, base_shape)] - return self.reshape(new_shape).expand(expand_shape).reshape(final_shape) + return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape) def _resolve_dim(self, dim:int, *, outer:bool=False) -> int: if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):