mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Implement Tensor.roll using movement ops only (#11317)

* movement-op-only Tensor.roll
* fixed
This commit is contained in:
@@ -97,14 +97,12 @@ class TestEmptyTensorEdgeCases(unittest.TestCase):
|
||||
class TestRollEdgeCases(unittest.TestCase):
|
||||
# we don't need more of these
|
||||
|
||||
@unittest.expectedFailure
def test_roll_mismatched_dims(self):
    # Both torch and tinygrad should reject a single shift paired with two dims.
    # Marked expectedFailure: tinygrad does not currently raise here.
    self.assertRaises(RuntimeError, torch.roll, torch.arange(9).reshape(3, 3), 1, (0, 1))
    self.assertRaises(RuntimeError, lambda: Tensor.arange(9).reshape(3, 3).roll(1, dims=(0, 1)))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_roll_extra_shift(self):
|
||||
# tinygrad ignores extra shift values instead of raising
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
||||
@@ -1939,7 +1939,7 @@ class TestOps(unittest.TestCase):
|
||||
def test_roll(self):
    """Tensor.roll agrees with torch.roll for int and tuple shifts, with and without dims."""
    # flat roll over the whole tensor (no dims): int shift and 1-tuple shift are equivalent
    helper_test_op([(2, 4)], lambda x: x.roll(1))
    helper_test_op([(2, 4)], lambda x: x.roll((1,)))
    # more shifts than dims must raise RuntimeError in both torch and tinygrad
    # (diff residue removed: the superseded `expected=(RuntimeError, AssertionError)` line is dropped,
    #  keeping only the post-change assertion)
    self.helper_test_exception([(2, 4)], lambda x: x.roll((1,2)), lambda x: x.roll((1,2)), expected=RuntimeError)
    # rolling along an explicit dim, positive and negative shifts
    helper_test_op([(2, 4)], lambda x: x.roll(1, 0))
    helper_test_op([(2, 4)], lambda x: x.roll(-1, 0))
    # multiple dims with a matching number of shifts
    helper_test_op([(2, 4)], lambda x: x.roll(shifts=(2, 1), dims=(0, 1)))
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, Suppor
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
|
||||
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
|
||||
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, get_single_element
|
||||
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, Variable, MathTrait, identity_element, all_metadata
|
||||
from tinygrad.uop.spec import tensor_uop_spec, type_verify
|
||||
@@ -1595,17 +1595,11 @@ class Tensor(MathTrait):
|
||||
print(t.roll(shifts=-1, dims=0).numpy())
|
||||
```
|
||||
"""
|
||||
if dims is None:
|
||||
shifts = shifts if isinstance(shifts, int) else get_single_element(shifts)
|
||||
if not isinstance(shifts, int): raise RuntimeError(f"{shifts=} must be int for {dims=}")
|
||||
start = self.numel() - shifts % self.numel()
|
||||
return self.flatten().repeat(2)[start:self.numel()+start].reshape(self.shape)
|
||||
dims, rolled = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), self
|
||||
for dim, shift in zip(dims, make_tuple(shifts, 1)):
|
||||
shift = shift % self.shape[dim]
|
||||
rolled = Tensor.cat(rolled[tuple(slice(None) if i != dim else slice(-shift, None) for i in range(rolled.ndim))],
|
||||
rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim)
|
||||
return rolled
|
||||
if dims is None: return self.flatten().roll(shifts, 0).reshape(self.shape)
|
||||
dims, shifts, slices = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), make_tuple(shifts, 1), [slice(None)] * self.ndim
|
||||
if len(dims) != len(shifts): raise RuntimeError(f"{len(dims)=} != {len(shifts)=}")
|
||||
for dim, shift in zip(dims, shifts): slices[dim] = slice(delta:=self.shape[dim]-shift%self.shape[dim], delta+self.shape[dim])
|
||||
return self.repeat(*tuple(2 if i in dims else 1 for i in range(self.ndim)))[slices]
|
||||
|
||||
def rearrange(self, formula:str, **sizes) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user