Tensor.roll touchup (#6398)

simplified a bit.
it might be able to write it with only movements, but the backward would contain a reduce.
This commit is contained in:
chenyu
2024-09-07 04:48:43 -04:00
committed by GitHub
parent 2e01efc35f
commit 3b2e1b922d

View File

@@ -1321,7 +1321,7 @@ class Tensor:
def roll(self, shifts:Union[int, Tuple[int, ...]], dims:Union[int, Tuple[int, ...]]) -> Tensor:
"""
Roll the tensor along specified dimension(s).
Rolls the tensor along specified dimension(s).
The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
```python exec="true" source="above" session="tensor" result="python"
@@ -1331,13 +1331,11 @@ class Tensor:
print(Tensor.rand(3, 4, 1).roll(shifts=-1, dims=0))
```
"""
dims, shifts = (dims,) if isinstance(dims, int) else dims, (shifts,) if isinstance(shifts, int) else shifts
dims = tuple(i % len(self.shape) for i in dims)
all_shifts = [shifts[dims.index(i)] % self.shape[i] if i in dims else 0 for i in range(len(self.shape))]
rolled = self
for i, shift in enumerate(all_shifts):
rolled = Tensor.cat(rolled[tuple(slice(None) if j != i else slice(-shift, None) for j in range(len(rolled.shape)))],
rolled[tuple(slice(None) if j != i else slice(None, -shift) for j in range(len(rolled.shape)))], dim=i)
dims, rolled = tuple(self._resolve_dim(d) for d in make_pair(dims, 1)), self
for dim, shift in zip(dims, make_pair(shifts, 1)):
shift = shift % self.shape[dim]
rolled = Tensor.cat(rolled[tuple(slice(None) if i != dim else slice(-shift, None) for i in range(rolled.ndim))],
rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim)
return rolled
# ***** reduce ops *****