movementop only Tensor.roll (#11317)

* movementop only Tensor.roll

* fixed
This commit is contained in:
chenyu
2025-07-22 10:34:15 -04:00
committed by GitHub
parent a41140241b
commit 1d8b3e9d1c
3 changed files with 7 additions and 15 deletions

View File

@@ -97,14 +97,12 @@ class TestEmptyTensorEdgeCases(unittest.TestCase):
class TestRollEdgeCases(unittest.TestCase):
# we don't need more of these
@unittest.expectedFailure
def test_roll_mismatched_dims(self):
with self.assertRaises(RuntimeError):
torch.roll(torch.arange(9).reshape(3, 3), 1, dims=(0, 1))
with self.assertRaises(RuntimeError):
Tensor.arange(9).reshape(3, 3).roll(1, dims=(0, 1))
@unittest.expectedFailure
def test_roll_extra_shift(self):
# tinygrad ignores extra shift values instead of raising
with self.assertRaises(RuntimeError):

View File

@@ -1939,7 +1939,7 @@ class TestOps(unittest.TestCase):
def test_roll(self):
helper_test_op([(2, 4)], lambda x: x.roll(1))
helper_test_op([(2, 4)], lambda x: x.roll((1,)))
self.helper_test_exception([(2, 4)], lambda x: x.roll((1,2)), lambda x: x.roll((1,2)), expected=(RuntimeError, AssertionError))
self.helper_test_exception([(2, 4)], lambda x: x.roll((1,2)), lambda x: x.roll((1,2)), expected=RuntimeError)
helper_test_op([(2, 4)], lambda x: x.roll(1, 0))
helper_test_op([(2, 4)], lambda x: x.roll(-1, 0))
helper_test_op([(2, 4)], lambda x: x.roll(shifts=(2, 1), dims=(0, 1)))

View File

@@ -6,7 +6,7 @@ from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, Suppor
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, get_single_element
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray
from tinygrad.gradient import compute_gradient
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, Variable, MathTrait, identity_element, all_metadata
from tinygrad.uop.spec import tensor_uop_spec, type_verify
@@ -1595,17 +1595,11 @@ class Tensor(MathTrait):
print(t.roll(shifts=-1, dims=0).numpy())
```
"""
if dims is None:
shifts = shifts if isinstance(shifts, int) else get_single_element(shifts)
if not isinstance(shifts, int): raise RuntimeError(f"{shifts=} must be int for {dims=}")
start = self.numel() - shifts % self.numel()
return self.flatten().repeat(2)[start:self.numel()+start].reshape(self.shape)
dims, rolled = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), self
for dim, shift in zip(dims, make_tuple(shifts, 1)):
shift = shift % self.shape[dim]
rolled = Tensor.cat(rolled[tuple(slice(None) if i != dim else slice(-shift, None) for i in range(rolled.ndim))],
rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim)
return rolled
if dims is None: return self.flatten().roll(shifts, 0).reshape(self.shape)
dims, shifts, slices = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), make_tuple(shifts, 1), [slice(None)] * self.ndim
if len(dims) != len(shifts): raise RuntimeError(f"{len(dims)=} != {len(shifts)=}")
for dim, shift in zip(dims, shifts): slices[dim] = slice(delta:=self.shape[dim]-shift%self.shape[dim], delta+self.shape[dim])
return self.repeat(*tuple(2 if i in dims else 1 for i in range(self.ndim)))[slices]
def rearrange(self, formula:str, **sizes) -> Tensor:
"""