From fde9f0e60de64d59162cd65987456fbe32f6bfb5 Mon Sep 17 00:00:00 2001 From: Umut Zengin <70948490+ZenginU@users.noreply.github.com> Date: Wed, 19 Jul 2023 19:08:38 +0300 Subject: [PATCH] Slice migrated in Eye op (#1281) * Migrated from slice to pad and shrink, made cleaner * Changed repeat with reshape and expand --- test/test_ops.py | 1 + tinygrad/tensor.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 20ab51c8b1..1d4c8f67a2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -96,6 +96,7 @@ class TestOps(unittest.TestCase): helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True) def test_eye(self): helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True) + helper_test_op([], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True) def test_arange(self): helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 977f182d4d..c4d94dbaa7 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -168,7 +168,9 @@ class Tensor: def ones_like(tensor, **kwargs): return Tensor.full_like(tensor, 1, **kwargs) @staticmethod - def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim) + def eye(dim:int, **kwargs): + return Tensor([1], **kwargs).pad(((0,dim),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim) + # ***** rng hlops *****