tensor roll (#6375)

* tensor roll function and tests

* fix type annotations

* reduce line count

* more readable
This commit is contained in:
Irakli Salia
2024-09-07 01:14:28 +04:00
committed by GitHub
parent dfb818788e
commit 2e01efc35f
2 changed files with 36 additions and 0 deletions

View File

@@ -1375,6 +1375,21 @@ class TestOps(unittest.TestCase):
helper_test_op([(4,3,6,6)], lambda x: x.unflatten(3, (3, 2)))
helper_test_op([(4,3,6,6)], lambda x: x.unflatten(-1, (3, 2, 1)))
def test_roll(self):
helper_test_op([(2, 4)], lambda x: torch.roll(x, 1, 0), lambda x: x.roll(1, 0))
helper_test_op([(2, 4)], lambda x: torch.roll(x, -1, 0), lambda x: x.roll(-1, 0))
helper_test_op([(2, 4)], lambda x: torch.roll(x, shifts=(2, 1), dims=(0, 1)), lambda x: x.roll(shifts=(2, 1), dims=(0, 1)))
helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, 1, 0), lambda x: x.roll(1, 0))
helper_test_op([(2, 4)], lambda x: torch.roll(x, 1, -1), lambda x: x.roll(1, -1))
helper_test_op([(2, 4)], lambda x: torch.roll(x, -1, -1), lambda x: x.roll(-1, -1))
helper_test_op([(2, 4)], lambda x: torch.roll(x, 5, 0), lambda x: x.roll(5, 0))
helper_test_op([(2, 4)], lambda x: torch.roll(x, -5, 0), lambda x: x.roll(-5, 0))
helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(2, -3), dims=(0, 2)), lambda x: x.roll(shifts=(2, -3), dims=(0, 2)))
helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(1, 2, -1), dims=(0, 1, 2)), lambda x: x.roll(shifts=(1, 2, -1), dims=(0, 1, 2)))
helper_test_op([(2, 4)], lambda x: torch.roll(x, 0, 0), lambda x: x.roll(0, 0))
helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(0, 0), dims=(0, 1)), lambda x: x.roll(shifts=(0, 0), dims=(0, 1)))
helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(0, 2), dims=(0, 1)), lambda x: x.roll(shifts=(0, 2), dims=(0, 1)))
def test_detach(self):
helper_test_op([(4,3,6,6)], lambda x: x.detach(), forward_only=True)
helper_test_op([()], lambda x: x.detach(), forward_only=True)

View File

@@ -1319,6 +1319,27 @@ class Tensor:
dim = self._resolve_dim(dim)
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
def roll(self, shifts:Union[int, Tuple[int, ...]], dims:Union[int, Tuple[int, ...]]) -> Tensor:
"""
Roll the tensor along specified dimension(s).
The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.rand(3, 4, 1).roll(shifts=1, dims=0))
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.rand(3, 4, 1).roll(shifts=-1, dims=0))
```
"""
dims, shifts = (dims,) if isinstance(dims, int) else dims, (shifts,) if isinstance(shifts, int) else shifts
dims = tuple(i % len(self.shape) for i in dims)
all_shifts = [shifts[dims.index(i)] % self.shape[i] if i in dims else 0 for i in range(len(self.shape))]
rolled = self
for i, shift in enumerate(all_shifts):
rolled = Tensor.cat(rolled[tuple(slice(None) if j != i else slice(-shift, None) for j in range(len(rolled.shape)))],
rolled[tuple(slice(None) if j != i else slice(None, -shift) for j in range(len(rolled.shape)))], dim=i)
return rolled
# ***** reduce ops *****
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor: