mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
Tensor.repeat cleanup (#7735)
flatten instead of double for loop comprehension
This commit is contained in:
@@ -1297,11 +1297,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
```
|
||||
"""
|
||||
repeats = argfix(repeats, *args)
|
||||
base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
|
||||
new_shape = [x for b in base_shape for x in [1, b]]
|
||||
expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
|
||||
base_shape = _pad_left(self.shape, repeats)[0]
|
||||
unsqueezed_shape = flatten([[1, s] for s in base_shape])
|
||||
expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)])
|
||||
final_shape = [r*s for r,s in zip(repeats, base_shape)]
|
||||
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
|
||||
return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
|
||||
|
||||
def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
|
||||
if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):
|
||||
|
||||
Reference in New Issue
Block a user