diff --git a/test/test_edgecases.py b/test/test_edgecases.py
index 001be44e0f..258fd2b602 100644
--- a/test/test_edgecases.py
+++ b/test/test_edgecases.py
@@ -97,14 +97,12 @@ class TestEmptyTensorEdgeCases(unittest.TestCase):
 
 class TestRollEdgeCases(unittest.TestCase):
   # we don't need more of these
-  @unittest.expectedFailure
   def test_roll_mismatched_dims(self):
     with self.assertRaises(RuntimeError):
       torch.roll(torch.arange(9).reshape(3, 3), 1, dims=(0, 1))
     with self.assertRaises(RuntimeError):
       Tensor.arange(9).reshape(3, 3).roll(1, dims=(0, 1))
 
-  @unittest.expectedFailure
   def test_roll_extra_shift(self):
     # tinygrad ignores extra shift values instead of raising
     with self.assertRaises(RuntimeError):
diff --git a/test/test_ops.py b/test/test_ops.py
index a3ee3e857c..eebe4a3fb4 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1939,7 +1939,7 @@ class TestOps(unittest.TestCase):
   def test_roll(self):
     helper_test_op([(2, 4)], lambda x: x.roll(1))
     helper_test_op([(2, 4)], lambda x: x.roll((1,)))
-    self.helper_test_exception([(2, 4)], lambda x: x.roll((1,2)), lambda x: x.roll((1,2)), expected=(RuntimeError, AssertionError))
+    self.helper_test_exception([(2, 4)], lambda x: x.roll((1,2)), lambda x: x.roll((1,2)), expected=RuntimeError)
     helper_test_op([(2, 4)], lambda x: x.roll(1, 0))
     helper_test_op([(2, 4)], lambda x: x.roll(-1, 0))
     helper_test_op([(2, 4)], lambda x: x.roll(shifts=(2, 1), dims=(0, 1)))
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 99018be74e..691684a2d0 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -6,7 +6,7 @@ from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, Suppor
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
-from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, get_single_element
+from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray
 from tinygrad.gradient import compute_gradient
 from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, Variable, MathTrait, identity_element, all_metadata
 from tinygrad.uop.spec import tensor_uop_spec, type_verify
@@ -1595,17 +1595,11 @@ class Tensor(MathTrait):
     print(t.roll(shifts=-1, dims=0).numpy())
     ```
     """
-    if dims is None:
-      shifts = shifts if isinstance(shifts, int) else get_single_element(shifts)
-      if not isinstance(shifts, int): raise RuntimeError(f"{shifts=} must be int for {dims=}")
-      start = self.numel() - shifts % self.numel()
-      return self.flatten().repeat(2)[start:self.numel()+start].reshape(self.shape)
-    dims, rolled = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), self
-    for dim, shift in zip(dims, make_tuple(shifts, 1)):
-      shift = shift % self.shape[dim]
-      rolled = Tensor.cat(rolled[tuple(slice(None) if i != dim else slice(-shift, None) for i in range(rolled.ndim))],
-                          rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim)
-    return rolled
+    if dims is None: return self.flatten().roll(shifts, 0).reshape(self.shape)
+    dims, shifts, slices = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), make_tuple(shifts, 1), [slice(None)] * self.ndim
+    if len(dims) != len(shifts): raise RuntimeError(f"{len(dims)=} != {len(shifts)=}")
+    for dim, shift in zip(dims, shifts): slices[dim] = slice(delta:=self.shape[dim]-shift%self.shape[dim], delta+self.shape[dim])
+    return self.repeat(*tuple(2 if i in dims else 1 for i in range(self.ndim)))[slices]
 
   def rearrange(self, formula:str, **sizes) -> Tensor:
     """
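
For context on the rewrite: the old `Tensor.roll` built the result by concatenating two slices per rolled axis, while the new version tiles the tensor twice along each rolled dimension and cuts a single full-size window out of the doubled tensor. The new `len(dims) != len(shifts)` check raises `RuntimeError` for extra shift values, matching torch, which is what lets the two `@unittest.expectedFailure` markers come off. Below is a minimal standalone sketch of the repeat-then-slice idea, written against NumPy purely for illustration (`roll_sketch` and the NumPy dependency are assumptions of this sketch, not part of the patch):

```python
import numpy as np

def roll_sketch(x: np.ndarray, shifts: tuple, dims: tuple) -> np.ndarray:
  # mirrors the patch's error path: every rolled dim needs exactly one shift
  if len(dims) != len(shifts): raise RuntimeError(f"{len(dims)=} != {len(shifts)=}")
  slices = [slice(None)] * x.ndim
  for dim, shift in zip(dims, shifts):
    # a circular shift by s on an axis of size n is the window starting at n - s % n in the doubled axis
    delta = x.shape[dim] - shift % x.shape[dim]
    slices[dim] = slice(delta, delta + x.shape[dim])
  # tile 2x along each rolled axis, then take one contiguous full-size window
  return np.tile(x, tuple(2 if i in dims else 1 for i in range(x.ndim)))[tuple(slices)]

x = np.arange(9).reshape(3, 3)
assert (roll_sketch(x, (1,), (0,)) == np.roll(x, 1, axis=0)).all()
assert (roll_sketch(x, (-1,), (1,)) == np.roll(x, -1, axis=1)).all()
assert (roll_sketch(x, (2, 1), (0, 1)) == np.roll(x, (2, 1), axis=(0, 1))).all()
```

Python's `%` returning a non-negative result for negative operands is what lets `n - s % n` handle negative and over-size shifts uniformly, and the `dims is None` case in the patch reduces to the same path by flattening and rolling along dim 0.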