mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Tensor.roll touchup (#6398)
simplified a bit. it might be able to write it with only movements, but the backward would contain a reduce.
This commit is contained in:
@@ -1321,7 +1321,7 @@ class Tensor:
|
||||
|
||||
def roll(self, shifts:Union[int, Tuple[int, ...]], dims:Union[int, Tuple[int, ...]]) -> Tensor:
|
||||
"""
|
||||
Roll the tensor along specified dimension(s).
|
||||
Rolls the tensor along specified dimension(s).
|
||||
The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
@@ -1331,13 +1331,11 @@ class Tensor:
|
||||
print(Tensor.rand(3, 4, 1).roll(shifts=-1, dims=0))
|
||||
```
|
||||
"""
|
||||
dims, shifts = (dims,) if isinstance(dims, int) else dims, (shifts,) if isinstance(shifts, int) else shifts
|
||||
dims = tuple(i % len(self.shape) for i in dims)
|
||||
all_shifts = [shifts[dims.index(i)] % self.shape[i] if i in dims else 0 for i in range(len(self.shape))]
|
||||
rolled = self
|
||||
for i, shift in enumerate(all_shifts):
|
||||
rolled = Tensor.cat(rolled[tuple(slice(None) if j != i else slice(-shift, None) for j in range(len(rolled.shape)))],
|
||||
rolled[tuple(slice(None) if j != i else slice(None, -shift) for j in range(len(rolled.shape)))], dim=i)
|
||||
dims, rolled = tuple(self._resolve_dim(d) for d in make_pair(dims, 1)), self
|
||||
for dim, shift in zip(dims, make_pair(shifts, 1)):
|
||||
shift = shift % self.shape[dim]
|
||||
rolled = Tensor.cat(rolled[tuple(slice(None) if i != dim else slice(-shift, None) for i in range(rolled.ndim))],
|
||||
rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim)
|
||||
return rolled
|
||||
|
||||
# ***** reduce ops *****
|
||||
|
||||
Reference in New Issue
Block a user