Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-07 22:23:55 -05:00
Add Tensor.unfold (#10518)
* yoinked 10272
* eitanturok's fixes
* hmmm should size be sint?
* add test
@@ -18,6 +18,7 @@
 ::: tinygrad.Tensor.repeat_interleave
 ::: tinygrad.Tensor.split
 ::: tinygrad.Tensor.chunk
+::: tinygrad.Tensor.unfold
 ::: tinygrad.Tensor.meshgrid
 ::: tinygrad.Tensor.squeeze
 ::: tinygrad.Tensor.unsqueeze
@@ -102,7 +102,7 @@ def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
 
 @torch.library.impl("aten::index_put", "privateuseone")
 def index_put(self, indices, values, accumulate=False):
-  return aten.index_put(self.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in indices], values.cpu(), accumulate).tiny()
+  return aten.index_put(self.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in indices], values.clone().cpu(), accumulate).tiny()
 
 @torch.library.impl("aten::isin.Tensor_Tensor_out", "privateuseone")
 def isin_tensor_tensor_out(x, y, *, assume_unique=False, invert=False, out=None): return out.copy_(aten.isin(x.cpu(), y.cpu(), assume_unique=assume_unique, invert=invert).tiny())
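Aside: `aten::index_put` is still served by the CPU round-trip fallback visible above; tensor arguments are copied to CPU, the stock aten kernel runs there, and the result is re-wrapped for the tiny device with `.tiny()`. The added `.clone()` on `values` presumably protects the source tensor from aliasing during that round trip. A minimal, backend-agnostic sketch of the same pattern (`cpu_fallback` and the `.to(device)` re-wrap are illustrative stand-ins, not the backend's actual helpers):

```python
import torch

def cpu_fallback(fn):
  # run `fn` on (cloned) CPU copies of its tensor arguments, then move the
  # result back to the device of the first tensor argument; the real backend
  # re-wraps with .tiny() instead of .to(device)
  def wrapped(*args, **kwargs):
    device = next((a.device for a in args if isinstance(a, torch.Tensor)), "cpu")
    cpu_args = [a.clone().cpu() if isinstance(a, torch.Tensor) else a for a in args]
    return fn(*cpu_args, **kwargs).to(device)
  return wrapped

safe_take = cpu_fallback(torch.take)
print(safe_take(torch.arange(8.), torch.tensor([0, 3])))  # tensor([0., 3.])
```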
@@ -391,6 +391,7 @@ decomps = [
   aten.hardsigmoid_backward,
   aten.leaky_relu_backward,
   aten.nll_loss2d_forward,
+  aten.unfold_backward,
   # NOTE: many of these don't work or cause infinite loops
   #aten.var_mean,
   #aten.var,
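Adding `aten.unfold_backward` to `decomps` means the gradient of `unfold` is decomposed into ops the backend already supports rather than needing a dedicated kernel. The gradient itself is easy to sanity-check with plain torch on CPU (no tiny backend required); overlapping windows accumulate:

```python
import torch

x = torch.arange(8., requires_grad=True)
w = x.unfold(0, 2, 1)  # 7 overlapping windows of length 2
w.sum().backward()
# interior elements appear in two windows, the endpoints in one
print(x.grad)          # tensor([1., 2., 2., 2., 2., 2., 2., 1.])
```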
@@ -565,7 +566,8 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
   "aten.ones_like": lambda self, dtype=None, device=None, **kwargs:
     self.ones_like(**{k: v for k, v in {"dtype": _from_torch_dtype(dtype) if dtype else None,
                                         "device": _from_torch_device(device) if device else None}.items() if v is not None}),
-  "aten.max.dim": lambda self, dim, keepdim=False: (self.max(dim, keepdim), self.argmax(dim, keepdim).cast(dtype=dtypes.int64))
+  "aten.max.dim": lambda self, dim, keepdim=False: (self.max(dim, keepdim), self.argmax(dim, keepdim).cast(dtype=dtypes.int64)),
+  "aten.unfold": Tensor.unfold,
 }}
 
 def wrap_fxn(k,f):
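The new `"aten.unfold": Tensor.unfold` entry can map straight through, with no wrapper, because the argument orders already agree: `aten::unfold(self, dimension, size, step)` versus `Tensor.unfold(self, dim, size, step)`. On the tinygrad side the result matches the docstring formula:

```python
from tinygrad import Tensor

t = Tensor.arange(8)
u = t.unfold(0, 2, 2)  # n_windows = (8 - 2) // 2 + 1 = 4
print(u.shape)         # (4, 2)
print(u.numpy())       # [[0 1] [2 3] [4 5] [6 7]]
```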
@@ -227,6 +227,19 @@ class TestOps(unittest.TestCase):
     for i in range(len(tor)):
       helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
 
+  def test_unfold(self):
+    helper_test_op([(8,)], lambda x: x.unfold(0, 2, 1))
+    helper_test_op([(8,)], lambda x: x.unfold(0, 2, 2))
+    helper_test_op([(8,)], lambda x: x.unfold(0, 7, 3))
+    helper_test_op([(3,3,3)], lambda x: x.unfold(2, 2, 8))
+    helper_test_op([(3,3,3)], lambda x: x.unfold(1, 0, 8))
+    helper_test_op([(3,3,3,3,3)], lambda x: x.unfold(-1, 2, 2))
+
+    self.helper_test_exception([(8,)], lambda x: x.unfold(0, 9, 3), lambda x: x.unfold(0, 9, 3), expected=RuntimeError)
+    self.helper_test_exception([(8,)], lambda x: x.unfold(1, 8, 3), lambda x: x.unfold(1, 8, 3), expected=IndexError)
+    self.helper_test_exception([(8,)], lambda x: x.unfold(0, -1, 3), lambda x: x.unfold(0, -1, 3), expected=RuntimeError)
+    self.helper_test_exception([(8,)], lambda x: x.unfold(0, 1, -1), lambda x: x.unfold(0, 1, -1), expected=RuntimeError)
+
   def test_meshgrid(self):
     x, xt = torch.tensor([0.,1.,2.], requires_grad=True), Tensor([0.,1.,2.], requires_grad=True)
     y, yt = torch.tensor([3.,4.,5.,6.], requires_grad=True), Tensor([3.,4.,5.,6.], requires_grad=True)
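The forward cases above exercise the window-count rule as much as the values; working `n_windows = (shape[dim] - size) // step + 1` by hand for each case (plain Python, for illustration):

```python
# (length of the unfolded dim, size, step) for each forward test case
cases = [(8, 2, 1), (8, 2, 2), (8, 7, 3), (3, 2, 8), (3, 0, 8), (3, 2, 2)]
for length, size, step in cases:
  n = (length - size) // step + 1
  print(f"length={length} size={size} step={step} -> {n} window(s) of {size}")
# 7, 4, 1, 1, 1, 1; note size=0 still yields one (empty) window
```

The exception cases mirror torch: a size larger than the dimension raises RuntimeError, an out-of-range dim raises IndexError, and negative size or step are rejected up front.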
@@ -1394,6 +1394,30 @@ class Tensor(MathTrait):
     dim = self._resolve_dim(dim)
     return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
 
+  def unfold(self, dim:int, size:sint, step:int) -> Tensor:
+    """
+    Unfolds the tensor along dimension `dim` into overlapping windows.
+
+    Each window has length `size` and consecutive windows start `step` elements apart.
+    Returns the input tensor with dimension `dim` resized to `n_windows` and a new dim of
+    length `size` appended last, where `n_windows = (self.shape[dim] - size) // step + 1`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    unfolded = Tensor.arange(8).unfold(0,2,2)
+    print("\\n".join([repr(x.numpy()) for x in unfolded]))
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    unfolded = Tensor.arange(27).reshape(3,3,3).unfold(-1,2,3)
+    print("\\n".join([repr(x.numpy()) for x in unfolded]))
+    ```
+    """
+    if size < 0: raise RuntimeError(f'size must be >= 0 but got {size=}')
+    if step <= 0: raise RuntimeError(f'step must be > 0 but got {step=}')
+    if size > self.shape[dim]: raise RuntimeError(f'maximum size for tensor at dimension {dim} is {self.shape[dim]} but size is {size}')
+    dim = self._resolve_dim(dim)
+    perm_to_last = tuple(i for i in range(self.ndim) if i != dim) + (dim,)
+    return self.permute(perm_to_last)._pool((size,), step).permute(argsort(perm_to_last) + (self.ndim,))
+
   def meshgrid(self:Tensor, *args:Tensor, indexing:Literal["ij", "xy"]="ij") -> tuple[Tensor, ...]:
     """
     Generates coordinate matrices from coordinate vectors.
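Implementation note: `_pool` only windows trailing axes, hence the permute-to-last / pool / permute-back dance; `argsort(perm_to_last)` gives the inverse permutation, and the trailing `(self.ndim,)` keeps the new window axis last. As a semantic reference only (a NumPy sketch, not tinygrad's code, which stays lazy and copy-free via movement ops):

```python
import numpy as np

def unfold_ref(a: np.ndarray, dim: int, size: int, step: int) -> np.ndarray:
  # move `dim` to the end, take n_windows strided slices of length `size`,
  # stack them on a new axis, then put the window-count axis back at `dim`
  dim = dim % a.ndim
  a = np.moveaxis(a, dim, -1)
  n = (a.shape[-1] - size) // step + 1
  windows = np.stack([a[..., i*step:i*step + size] for i in range(n)], axis=-2)
  return np.moveaxis(windows, -2, dim)

x = np.arange(27).reshape(3, 3, 3)
assert unfold_ref(x, -1, 2, 3).shape == (3, 3, 1, 2)  # matches the docstring example
```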