From 6add808f6abd524ad2a7f2642263ce7420ccdf72 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 19 Nov 2023 20:20:39 -0500 Subject: [PATCH] support tuple shape input for rand and empty (#2367) --- test/test_tensor.py | 6 ++++++ tinygrad/tensor.py | 7 +++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/test/test_tensor.py b/test/test_tensor.py index 4327b3de41..39572cdc84 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -205,6 +205,12 @@ class TestTinygrad(unittest.TestCase): self.assertEqual(Tensor.zeros([10,20,40]).shape, (10,20,40)) self.assertEqual(Tensor.ones([10,20,40]).shape, (10,20,40)) + self.assertEqual(Tensor.rand(1,10,20).shape, (1,10,20)) + self.assertEqual(Tensor.rand((10,20,40)).shape, (10,20,40)) + + self.assertEqual(Tensor.empty(1,10,20).shape, (1,10,20)) + self.assertEqual(Tensor.empty((10,20,40)).shape, (10,20,40)) + def test_numel(self): assert Tensor.randn(10, 10).numel() == 100 assert Tensor.randn(1,2,5).numel() == 10 diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d938a67dec..d8fcce92ba 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -133,12 +133,12 @@ class Tensor: @staticmethod def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs): + assert isinstance(sz, int), f"cannot create with symbolic size {sz}" return Tensor(LazyBuffer.loadop(op, (sz,), Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs) @staticmethod def empty(*shape, **kwargs): - assert all_int(shape), f"cannot create with symbolic shape {shape}" - return Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs).reshape(shape) + return Tensor._loadop(LoadOps.EMPTY, prod((shape:=argfix(*shape))), **kwargs).reshape(shape) _seed: int = int(time.time()) @staticmethod @@ -146,9 +146,8 @@ class Tensor: @staticmethod def rand(*shape, **kwargs): - assert all_int(shape), f"cannot create with symbolic shape {shape}" Tensor._seed += 1 - return Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs).reshape(shape) + return Tensor._loadop(LoadOps.RAND, prod((shape:=argfix(*shape))), arg=Tensor._seed, **kwargs).reshape(shape) # ***** creation helper functions *****