diff --git a/test/test_ops.py b/test/test_ops.py index 20ab51c8b1..1d4c8f67a2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -96,6 +96,7 @@ class TestOps(unittest.TestCase): helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True) def test_eye(self): helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True) + helper_test_op([], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True) def test_arange(self): helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 977f182d4d..c4d94dbaa7 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -168,7 +168,9 @@ class Tensor: def ones_like(tensor, **kwargs): return Tensor.full_like(tensor, 1, **kwargs) @staticmethod - def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim) + def eye(dim:int, **kwargs): + return Tensor([1], **kwargs).pad(((0,dim),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim) + # ***** rng hlops *****