From 502e33652f25cfcd0662d3200cfe7939c43d0056 Mon Sep 17 00:00:00 2001 From: Ubaidullah Khan <35372700+ubihot@users.noreply.github.com> Date: Tue, 30 May 2023 02:48:09 +0200 Subject: [PATCH] add Tensor.full and Tensor.full_like and reuse them (#852) * add Tensor.ones_like() * add full_like and full and reuse in zeros,ones * add tests for full and full_like --- test/test_ops.py | 14 ++++++++++++++ test/test_tensor.py | 12 ++++++++++++ tinygrad/tensor.py | 18 +++++++++++++++--- 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index a7bfaeaa9c..8862e54a91 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -54,12 +54,26 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="") class TestOps(unittest.TestCase): + def test_full_like(self): + a = Tensor([[1,2,3],[4,5,6]]) + b = torch.tensor([[1,2,3],[4,5,6]]) + helper_test_op([], lambda: torch.full_like(b, 4), lambda: Tensor.full_like(a, 4), forward_only=True) + def test_full(self): + helper_test_op([], lambda: torch.full((45,65), 4), lambda: Tensor.full((45,65), 4), forward_only=True) def test_zeros(self): helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True) + def test_zeros_like(self): + a = Tensor([[1,2,3],[4,5,6]]) + b = torch.tensor([[1,2,3],[4,5,6]]) + helper_test_op([], lambda: torch.zeros_like(b), lambda: Tensor.zeros_like(a), forward_only=True) def test_empty_0(self): helper_test_op([], lambda: torch.empty(45,65)*0/0, lambda: Tensor.empty(45,65)*0/0, forward_only=True) def test_ones(self): helper_test_op([], lambda: torch.ones(45,65), lambda: Tensor.ones(45,65), forward_only=True) + def test_ones_like(self): + a = Tensor([[1,2,3],[4,5,6]]) + b = torch.tensor([[1,2,3],[4,5,6]]) + helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True) def test_eye(self): helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True) def test_arange(self): diff --git a/test/test_tensor.py b/test/test_tensor.py index 014e6e3505..4164d2431b 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -166,6 +166,18 @@ class TestTinygrad(unittest.TestCase): b = Tensor.zeros_like(a, dtype=dtypes.int8) assert a.dtype != b.dtype and a.dtype == dtypes.float32 and b.dtype == dtypes.int8, "a.dtype should be float and b.dtype should be char" assert a.shape == b.shape, f"shape mismatch (Tensor.zeros_like){a.shape} != (torch){b.shape}" + + def test_ones_like_has_same_dtype_and_shape(self): + for datatype in [dtypes.float16, dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64, dtypes.uint8]: + a = Tensor([1, 2, 3], dtype=datatype) + b = Tensor.ones_like(a) + assert a.dtype == b.dtype, f"a.dtype and b.dtype should be {datatype}" + assert a.shape == b.shape, f"shape mismatch (Tensor.ones_like){a.shape} != (torch){b.shape}" + + a = Tensor([1, 2, 3]) + b = Tensor.ones_like(a, dtype=dtypes.int8) + assert a.dtype != b.dtype and a.dtype == dtypes.float32 and b.dtype == dtypes.int8, "a.dtype should be float and b.dtype should be char" + assert a.shape == b.shape, f"shape mismatch (Tensor.ones_like){a.shape} != (torch){b.shape}" if __name__ == '__main__': unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index c45590b149..fed02fee44 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -122,13 +122,25 @@ class Tensor: # ***** creation helper functions ***** @staticmethod - def zeros(*shape, **kwargs): return Tensor([0], **kwargs).reshape([1]*len(shape)).expand(shape).contiguous() + def full(shape:Tuple[int, ...], fill_value, **kwargs): + new_shape = argfix(shape) + return Tensor([fill_value], **kwargs).reshape([1]*len(new_shape)).expand(new_shape).contiguous() @staticmethod - def ones(*shape, **kwargs): return Tensor([1], **kwargs).reshape([1]*len(shape)).expand(shape).contiguous() + def zeros(*shape, **kwargs): return Tensor.full(shape, 0, **kwargs) @staticmethod - def zeros_like(tensor, dtype:Optional[DType]=None, **kwargs): return Tensor.zeros(*tensor.shape, dtype=tensor.dtype if dtype is None else dtype, **kwargs) + def ones(*shape, **kwargs): return Tensor.full(shape, 1, **kwargs) + + @staticmethod + def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs): + return Tensor.full(tensor.shape, fill_value, dtype=tensor.dtype if dtype is None else dtype, **kwargs) + + @staticmethod + def zeros_like(tensor, **kwargs): return Tensor.full_like(tensor, 0, **kwargs) + + @staticmethod + def ones_like(tensor, **kwargs): return Tensor.full_like(tensor, 1, **kwargs) @staticmethod def empty(*shape, device=Device.DEFAULT, dtype:Optional[DType]=None, **kwargs):