Add Tensor.unfold (#10518)

* yoinked 10272

* eitanturok's fixes

* hmmm should size be sint?

* add test
This commit is contained in:
geohotstan
2025-05-26 23:15:44 +08:00
committed by GitHub
parent 9169dcfb49
commit 602a145f8f
4 changed files with 42 additions and 2 deletions

View File

@@ -18,6 +18,7 @@
::: tinygrad.Tensor.repeat_interleave
::: tinygrad.Tensor.split
::: tinygrad.Tensor.chunk
::: tinygrad.Tensor.unfold
::: tinygrad.Tensor.meshgrid
::: tinygrad.Tensor.squeeze
::: tinygrad.Tensor.unsqueeze

View File

@@ -102,7 +102,7 @@ def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
@torch.library.impl("aten::index_put", "privateuseone")
def index_put(self, indices, values, accumulate=False):
return aten.index_put(self.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in indices], values.cpu(), accumulate).tiny()
return aten.index_put(self.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in indices], values.clone().cpu(), accumulate).tiny()
@torch.library.impl("aten::isin.Tensor_Tensor_out", "privateuseone")
def isin_tensor_tensor_out(x, y, *, assume_unique=False, invert=False, out=None): return out.copy_(aten.isin(x.cpu(), y.cpu(), assume_unique=assume_unique, invert=invert).tiny())
@@ -391,6 +391,7 @@ decomps = [
aten.hardsigmoid_backward,
aten.leaky_relu_backward,
aten.nll_loss2d_forward,
aten.unfold_backward,
# NOTE: many of these don't work or cause infinite loops
#aten.var_mean,
#aten.var,
@@ -565,7 +566,8 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.ones_like": lambda self, dtype=None, device=None, **kwargs:
self.ones_like(**{k: v for k, v in {"dtype": _from_torch_dtype(dtype) if dtype else None,
"device": _from_torch_device(device) if device else None}.items() if v is not None}),
"aten.max.dim": lambda self, dim, keepdim=False: (self.max(dim, keepdim), self.argmax(dim, keepdim).cast(dtype=dtypes.int64))
"aten.max.dim": lambda self, dim, keepdim=False: (self.max(dim, keepdim), self.argmax(dim, keepdim).cast(dtype=dtypes.int64)),
"aten.unfold": Tensor.unfold,
}}
def wrap_fxn(k,f):

View File

@@ -227,6 +227,19 @@ class TestOps(unittest.TestCase):
for i in range(len(tor)):
helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
def test_unfold(self):
helper_test_op([(8,)], lambda x: x.unfold(0, 2, 1))
helper_test_op([(8,)], lambda x: x.unfold(0, 2, 2))
helper_test_op([(8,)], lambda x: x.unfold(0, 7, 3))
helper_test_op([(3,3,3)], lambda x: x.unfold(2, 2, 8))
helper_test_op([(3,3,3)], lambda x: x.unfold(1, 0, 8))
helper_test_op([(3,3,3,3,3)], lambda x: x.unfold(-1, 2, 2))
self.helper_test_exception([(8,)], lambda x: x.unfold(0, 9, 3), lambda x: x.unfold(0, 9, 3), expected=RuntimeError)
self.helper_test_exception([(8,)], lambda x: x.unfold(1, 8, 3), lambda x: x.unfold(1, 8, 3), expected=IndexError)
self.helper_test_exception([(8,)], lambda x: x.unfold(0, -1, 3), lambda x: x.unfold(0, 9, 3), expected=RuntimeError)
self.helper_test_exception([(8,)], lambda x: x.unfold(0, 1, -1), lambda x: x.unfold(0, 9, 3), expected=RuntimeError)
def test_meshgrid(self):
x, xt = torch.tensor([0.,1.,2.], requires_grad=True), Tensor([0.,1.,2.], requires_grad=True)
y, yt = torch.tensor([3.,4.,5.,6.], requires_grad=True), Tensor([3.,4.,5.,6.], requires_grad=True)

View File

@@ -1394,6 +1394,30 @@ class Tensor(MathTrait):
dim = self._resolve_dim(dim)
return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
def unfold(self, dim:int, size:sint, step:int) -> Tensor:
"""
Unfolds the tensor along dimension `dim` into overlapping windows.
Each window has length `size` and begins every `step` elements of `self`.
Returns the input tensor with dimension `dim` replaced by dims `(n_windows, size)`
where `n_windows = (self.shape[dim] - size) // step + 1`.
```python exec="true" source="above" session="tensor" result="python"
unfolded = Tensor.arange(8).unfold(0,2,2)
print("\\n".join([repr(x.numpy()) for x in unfolded]))
```
```python exec="true" source="above" session="tensor" result="python"
unfolded = Tensor.arange(27).reshape(3,3,3).unfold(-1,2,3)
print("\\n".join([repr(x.numpy()) for x in unfolded]))
```
"""
if size < 0: raise RuntimeError(f'size must be >= 0 but got {size=}')
if step <= 0: raise RuntimeError(f'step must be > 0 but got {step=}')
if size > self.shape[dim]: raise RuntimeError(f'maximum size for tensor at dimension {dim} is {self.shape[dim]} but size is {size}')
dim = self._resolve_dim(dim)
perm_to_last = tuple(i for i in range(self.ndim) if i != dim) + (dim,)
return self.permute(perm_to_last)._pool((size,), step).permute(argsort(perm_to_last) + (self.ndim,))
def meshgrid(self:Tensor, *args:Tensor, indexing:Literal["ij", "xy"]="ij") -> tuple[Tensor, ...]:
"""
Generates coordinate matrices from coordinate vectors.