From 7aecea4f5694b09f783c2a96eb8558494edda724 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 13 Jun 2024 12:18:37 -0400 Subject: [PATCH] support creating Tensor from python tuple (#4944) added a small fuzzer to test data with mixed tuple and list of numbers matched with numpy --- test/test_tensor.py | 14 +++++++++++++- tinygrad/tensor.py | 2 +- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/test/test_tensor.py b/test/test_tensor.py index 025731037d..453112d7bb 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1,6 +1,6 @@ import numpy as np import torch -import unittest, copy, mmap +import unittest, copy, mmap, random from tinygrad import Tensor, Device, dtypes from tinygrad.helpers import getenv, temp, CI from extra.gradcheck import numerical_jacobian, jacobian, gradcheck @@ -288,6 +288,18 @@ class TestTinygrad(unittest.TestCase): with self.assertRaises(ValueError): Tensor([[[1,1,1],[1,1]]]) with self.assertRaises(ValueError): Tensor([[1,1,1],[[1,1,1]]]) + def test_tensor_mixed_list_tuple(self): + def _list_or_tuple(): return list if random.random() < 0.5 else tuple + def _generate_data(depth): + if depth == 0: return _list_or_tuple()() + if depth == 1: return _list_or_tuple()([random.random(), random.random()]) + return _list_or_tuple()([_generate_data(depth-1), _generate_data(depth-1)]) + + for depth in range(7): + for _ in range(20): + data = _generate_data(depth) + np.testing.assert_allclose(Tensor(data).numpy(), np.array(data)) + def test_tensor_copy(self): x = copy.deepcopy(Tensor.ones((3,3,3))) np.testing.assert_allclose(x.numpy(), np.ones((3,3,3))) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d5ee96070a..596c9adb8b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -108,7 +108,7 @@ class Tensor: elif isinstance(data, get_args(ConstType)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data) elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data) elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8)) - elif isinstance(data, list): + elif isinstance(data, (list, tuple)): if dtype is None: if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float