diff --git a/test/test_ops.py b/test/test_ops.py
index ea841dd6c4..cadceb2c7d 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1375,6 +1375,21 @@ class TestOps(unittest.TestCase):
     helper_test_op([(4,3,6,6)], lambda x: x.unflatten(3, (3, 2)))
     helper_test_op([(4,3,6,6)], lambda x: x.unflatten(-1, (3, 2, 1)))
 
+  def test_roll(self):
+    helper_test_op([(2, 4)], lambda x: torch.roll(x, 1, 0), lambda x: x.roll(1, 0))
+    helper_test_op([(2, 4)], lambda x: torch.roll(x, -1, 0), lambda x: x.roll(-1, 0))
+    helper_test_op([(2, 4)], lambda x: torch.roll(x, shifts=(2, 1), dims=(0, 1)), lambda x: x.roll(shifts=(2, 1), dims=(0, 1)))
+    helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, 1, 0), lambda x: x.roll(1, 0))
+    helper_test_op([(2, 4)], lambda x: torch.roll(x, 1, -1), lambda x: x.roll(1, -1))
+    helper_test_op([(2, 4)], lambda x: torch.roll(x, -1, -1), lambda x: x.roll(-1, -1))
+    helper_test_op([(2, 4)], lambda x: torch.roll(x, 5, 0), lambda x: x.roll(5, 0))
+    helper_test_op([(2, 4)], lambda x: torch.roll(x, -5, 0), lambda x: x.roll(-5, 0))
+    helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(2, -3), dims=(0, 2)), lambda x: x.roll(shifts=(2, -3), dims=(0, 2)))
+    helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(1, 2, -1), dims=(0, 1, 2)), lambda x: x.roll(shifts=(1, 2, -1), dims=(0, 1, 2)))
+    helper_test_op([(2, 4)], lambda x: torch.roll(x, 0, 0), lambda x: x.roll(0, 0))
+    helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(0, 0), dims=(0, 1)), lambda x: x.roll(shifts=(0, 0), dims=(0, 1)))
+    helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(0, 2), dims=(0, 1)), lambda x: x.roll(shifts=(0, 2), dims=(0, 1)))
+
   def test_detach(self):
     helper_test_op([(4,3,6,6)], lambda x: x.detach(), forward_only=True)
     helper_test_op([()], lambda x: x.detach(), forward_only=True)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 0606eb7333..7d7df4f9c0 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1319,6 +1319,27 @@ class Tensor:
     dim = self._resolve_dim(dim)
     return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
 
+  def roll(self, shifts:Union[int, Tuple[int, ...]], dims:Union[int, Tuple[int, ...]]) -> Tensor:
+    """
+    Roll the tensor along specified dimension(s).
+    The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.rand(3, 4, 1).roll(shifts=1, dims=0))
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.rand(3, 4, 1).roll(shifts=-1, dims=0))
+    ```
+    """
+    dims, shifts = (dims,) if isinstance(dims, int) else dims, (shifts,) if isinstance(shifts, int) else shifts
+    dims = tuple(i % len(self.shape) for i in dims)
+    all_shifts = [shifts[dims.index(i)] % self.shape[i] if i in dims else 0 for i in range(len(self.shape))]
+    rolled = self
+    for i, shift in enumerate(all_shifts):
+      rolled = Tensor.cat(rolled[tuple(slice(None) if j != i else slice(-shift, None) for j in range(len(rolled.shape)))],
+                          rolled[tuple(slice(None) if j != i else slice(None, -shift) for j in range(len(rolled.shape)))], dim=i)
+    return rolled
+
   # ***** reduce ops *****
 
   def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
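
For review context: the `roll()` added here normalizes each shift modulo the dimension size, then builds the rolled tensor as `cat(tail, head)` along that axis. Below is a minimal sketch of the same slice-and-concatenate identity on plain Python lists; `roll_list` is a hypothetical helper for illustration only, not part of tinygrad or this PR.

```python
# Sketch of the identity roll() relies on: rolling by `shift` is the
# tail xs[-shift:] followed by the head xs[:-shift], after reducing
# the shift modulo the length.
def roll_list(xs: list, shift: int) -> list:
  shift = shift % len(xs)  # wraps negative and oversized shifts, as roll() does
  # For shift == 0, -0 aliases to 0, so this degenerates to the whole
  # list plus an empty slice -- mirroring the cat-with-an-empty-slice
  # case in the tensor implementation above.
  return xs[-shift:] + xs[:-shift]

assert roll_list([1, 2, 3, 4], 1)  == [4, 1, 2, 3]
assert roll_list([1, 2, 3, 4], -1) == [2, 3, 4, 1]
assert roll_list([1, 2, 3, 4], 5)  == [4, 1, 2, 3]  # 5 % 4 == 1
assert roll_list([1, 2, 3, 4], 0)  == [1, 2, 3, 4]
```

These are the same case families the new `test_roll` checks against `torch.roll`: positive, negative, oversized (shift larger than the dimension), and zero shifts, on single and multiple dimensions.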