Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-22 05:18:01 -05:00
tensor roll (#6375)
* tensor roll function and tests
* fix type annotations
* reduce line count
* more readable
@@ -1375,6 +1375,21 @@ class TestOps(unittest.TestCase):
    helper_test_op([(4,3,6,6)], lambda x: x.unflatten(3, (3, 2)))
    helper_test_op([(4,3,6,6)], lambda x: x.unflatten(-1, (3, 2, 1)))

  def test_roll(self):
    helper_test_op([(2, 4)], lambda x: torch.roll(x, 1, 0), lambda x: x.roll(1, 0))
    helper_test_op([(2, 4)], lambda x: torch.roll(x, -1, 0), lambda x: x.roll(-1, 0))
    helper_test_op([(2, 4)], lambda x: torch.roll(x, shifts=(2, 1), dims=(0, 1)), lambda x: x.roll(shifts=(2, 1), dims=(0, 1)))
    helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, 1, 0), lambda x: x.roll(1, 0))
    helper_test_op([(2, 4)], lambda x: torch.roll(x, 1, -1), lambda x: x.roll(1, -1))
    helper_test_op([(2, 4)], lambda x: torch.roll(x, -1, -1), lambda x: x.roll(-1, -1))
    helper_test_op([(2, 4)], lambda x: torch.roll(x, 5, 0), lambda x: x.roll(5, 0))
    helper_test_op([(2, 4)], lambda x: torch.roll(x, -5, 0), lambda x: x.roll(-5, 0))
    helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(2, -3), dims=(0, 2)), lambda x: x.roll(shifts=(2, -3), dims=(0, 2)))
    helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(1, 2, -1), dims=(0, 1, 2)), lambda x: x.roll(shifts=(1, 2, -1), dims=(0, 1, 2)))
    helper_test_op([(2, 4)], lambda x: torch.roll(x, 0, 0), lambda x: x.roll(0, 0))
    helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(0, 0), dims=(0, 1)), lambda x: x.roll(shifts=(0, 0), dims=(0, 1)))
    helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(0, 2), dims=(0, 1)), lambda x: x.roll(shifts=(0, 2), dims=(0, 1)))

  def test_detach(self):
    helper_test_op([(4,3,6,6)], lambda x: x.detach(), forward_only=True)
    helper_test_op([()], lambda x: x.detach(), forward_only=True)
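For orientation, here is a small sketch of what these tests exercise, written against the `Tensor.roll` API added in this commit. It assumes the usual top-level `from tinygrad import Tensor` import and realization via `.numpy()`; the expected values follow `torch.roll` semantics, where a positive shift moves each element to a higher index and wraps around.

```python
from tinygrad import Tensor

t = Tensor([[0, 1, 2, 3],
            [4, 5, 6, 7]])

# positive shift along dim 0: the last row wraps around to the front
print(t.roll(1, 0).numpy())    # [[4 5 6 7]
                               #  [0 1 2 3]]

# negative shift along the last dim: each row shifts left by one, wrapping
print(t.roll(-1, -1).numpy())  # [[1 2 3 0]
                               #  [5 6 7 4]]
```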
||||
@@ -1319,6 +1319,27 @@ class Tensor:
    dim = self._resolve_dim(dim)
    return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])

  def roll(self, shifts:Union[int, Tuple[int, ...]], dims:Union[int, Tuple[int, ...]]) -> Tensor:
    """
    Roll the tensor along specified dimension(s).
    The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.

    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.rand(3, 4, 1).roll(shifts=1, dims=0))
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.rand(3, 4, 1).roll(shifts=-1, dims=0))
    ```
    """
    dims, shifts = (dims,) if isinstance(dims, int) else dims, (shifts,) if isinstance(shifts, int) else shifts
    dims = tuple(i % len(self.shape) for i in dims)
    all_shifts = [shifts[dims.index(i)] % self.shape[i] if i in dims else 0 for i in range(len(self.shape))]
    rolled = self
    for i, shift in enumerate(all_shifts):
      rolled = Tensor.cat(rolled[tuple(slice(None) if j != i else slice(-shift, None) for j in range(len(rolled.shape)))],
                          rolled[tuple(slice(None) if j != i else slice(None, -shift) for j in range(len(rolled.shape)))], dim=i)
    return rolled

  # ***** reduce ops *****

  def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
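The method rolls one axis at a time: it normalizes each shift modulo the axis length, slices off the last `shift` elements, and concatenates them in front of the remaining elements. Below is a minimal NumPy sketch of that same slice-and-concatenate idea for one dimension; the function name and the explicit zero-shift early return are illustrative, not part of the commit.

```python
import numpy as np

def roll_1d(a: np.ndarray, shift: int) -> np.ndarray:
  """Circular shift of a 1-D array via two slices and a concatenate."""
  shift %= len(a)                 # normalize negative and oversized shifts
  if shift == 0: return a.copy()  # nothing wraps around
  # the last `shift` elements move to the front, the rest follow unchanged
  return np.concatenate([a[-shift:], a[:-shift]])

print(roll_1d(np.arange(6), 2))   # [4 5 0 1 2 3]
print(roll_1d(np.arange(6), -2))  # [2 3 4 5 0 1]
```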
||||