From 602a145f8f5ff21fa64f9dc6c86c5203a6b4f7be Mon Sep 17 00:00:00 2001 From: geohotstan <135171913+geohotstan@users.noreply.github.com> Date: Mon, 26 May 2025 23:15:44 +0800 Subject: [PATCH] Add Tensor.unfold (#10518) * yoinked 10272 * eitanturok's fixes * hmmm should size be sint? * add test --- docs/tensor/movement.md | 1 + extra/torch_backend/backend.py | 6 ++++-- test/test_ops.py | 13 +++++++++++++ tinygrad/tensor.py | 24 ++++++++++++++++++++++++ 4 files changed, 42 insertions(+), 2 deletions(-) diff --git a/docs/tensor/movement.md b/docs/tensor/movement.md index 3d35dbb4b5..7c40bdc94e 100644 --- a/docs/tensor/movement.md +++ b/docs/tensor/movement.md @@ -18,6 +18,7 @@ ::: tinygrad.Tensor.repeat_interleave ::: tinygrad.Tensor.split ::: tinygrad.Tensor.chunk +::: tinygrad.Tensor.unfold ::: tinygrad.Tensor.meshgrid ::: tinygrad.Tensor.squeeze ::: tinygrad.Tensor.unsqueeze diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index 4a14a22ab7..a86b3d5477 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -102,7 +102,7 @@ def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False): @torch.library.impl("aten::index_put", "privateuseone") def index_put(self, indices, values, accumulate=False): - return aten.index_put(self.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in indices], values.cpu(), accumulate).tiny() + return aten.index_put(self.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in indices], values.clone().cpu(), accumulate).tiny() @torch.library.impl("aten::isin.Tensor_Tensor_out", "privateuseone") def isin_tensor_tensor_out(x, y, *, assume_unique=False, invert=False, out=None): return out.copy_(aten.isin(x.cpu(), y.cpu(), assume_unique=assume_unique, invert=invert).tiny()) @@ -391,6 +391,7 @@ decomps = [ aten.hardsigmoid_backward, aten.leaky_relu_backward, aten.nll_loss2d_forward, + aten.unfold_backward, # NOTE: many of these don't work or cause infinite 
loops #aten.var_mean, #aten.var, @@ -565,7 +566,8 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{ "aten.ones_like": lambda self, dtype=None, device=None, **kwargs: self.ones_like(**{k: v for k, v in {"dtype": _from_torch_dtype(dtype) if dtype else None, "device": _from_torch_device(device) if device else None}.items() if v is not None}), - "aten.max.dim": lambda self, dim, keepdim=False: (self.max(dim, keepdim), self.argmax(dim, keepdim).cast(dtype=dtypes.int64)) + "aten.max.dim": lambda self, dim, keepdim=False: (self.max(dim, keepdim), self.argmax(dim, keepdim).cast(dtype=dtypes.int64)), + "aten.unfold": Tensor.unfold, }} def wrap_fxn(k,f): diff --git a/test/test_ops.py b/test/test_ops.py index 3fdd91fee3..f69f02209c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -227,6 +227,19 @@ class TestOps(unittest.TestCase): for i in range(len(tor)): helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True) + def test_unfold(self): + helper_test_op([(8,)], lambda x: x.unfold(0, 2, 1)) + helper_test_op([(8,)], lambda x: x.unfold(0, 2, 2)) + helper_test_op([(8,)], lambda x: x.unfold(0, 7, 3)) + helper_test_op([(3,3,3)], lambda x: x.unfold(2, 2, 8)) + helper_test_op([(3,3,3)], lambda x: x.unfold(1, 0, 8)) + helper_test_op([(3,3,3,3,3)], lambda x: x.unfold(-1, 2, 2)) + + self.helper_test_exception([(8,)], lambda x: x.unfold(0, 9, 3), lambda x: x.unfold(0, 9, 3), expected=RuntimeError) + self.helper_test_exception([(8,)], lambda x: x.unfold(1, 8, 3), lambda x: x.unfold(1, 8, 3), expected=IndexError) + self.helper_test_exception([(8,)], lambda x: x.unfold(0, -1, 3), lambda x: x.unfold(0, -1, 3), expected=RuntimeError) + self.helper_test_exception([(8,)], lambda x: x.unfold(0, 1, -1), lambda x: x.unfold(0, 1, -1), expected=RuntimeError) + def test_meshgrid(self): x, xt = torch.tensor([0.,1.,2.], requires_grad=True), Tensor([0.,1.,2.], requires_grad=True) y, yt = torch.tensor([3.,4.,5.,6.], requires_grad=True), Tensor([3.,4.,5.,6.], 
requires_grad=True) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index b0c291380f..42ff44b38e 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1394,6 +1394,30 @@ class Tensor(MathTrait): dim = self._resolve_dim(dim) return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim)) + def unfold(self, dim:int, size:sint, step:int) -> Tensor: + """ + Unfolds the tensor along dimension `dim` into overlapping windows. + + Each window has length `size` and begins every `step` elements of `self`. + Returns the input tensor with dimension `dim` replaced by dims `(n_windows, size)` + where `n_windows = (self.shape[dim] - size) // step + 1`. + + ```python exec="true" source="above" session="tensor" result="python" + unfolded = Tensor.arange(8).unfold(0,2,2) + print("\\n".join([repr(x.numpy()) for x in unfolded])) + ``` + ```python exec="true" source="above" session="tensor" result="python" + unfolded = Tensor.arange(27).reshape(3,3,3).unfold(-1,2,3) + print("\\n".join([repr(x.numpy()) for x in unfolded])) + ``` + """ + dim = self._resolve_dim(dim) + if size < 0: raise RuntimeError(f'size must be >= 0 but got {size=}') + if step <= 0: raise RuntimeError(f'step must be > 0 but got {step=}') + if size > self.shape[dim]: raise RuntimeError(f'maximum size for tensor at dimension {dim} is {self.shape[dim]} but size is {size}') + perm_to_last = tuple(i for i in range(self.ndim) if i != dim) + (dim,) + return self.permute(perm_to_last)._pool((size,), step).permute(argsort(perm_to_last) + (self.ndim,)) + def meshgrid(self:Tensor, *args:Tensor, indexing:Literal["ij", "xy"]="ij") -> tuple[Tensor, ...]: """ Generates coordinate matrices from coordinate vectors.