From 287d3c3b84e95797a999e9c775948211bd926871 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 13 Jun 2024 13:38:06 -0400 Subject: [PATCH] support list, tuple input in dtypes.from_py (#4945) * support list, tuple input in dtypes.from_py and used it to infer dtype from python list and tuple in Tensor constructor. * fix tests --- test/test_dtype.py | 16 ++++++++++++++++ test/test_tensor.py | 2 +- tinygrad/dtype.py | 10 ++++++++-- tinygrad/tensor.py | 8 +++----- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index 7a7b3af546..2988805f61 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -329,6 +329,22 @@ class TestHelpers(unittest.TestCase): def test_scalar(self, dtype, amt): assert dtype.vec(amt).scalar() == dtype + def test_from_py(self): + assert dtypes.from_py(True) == dtypes.bool + assert dtypes.from_py(2) == dtypes.default_int + assert dtypes.from_py(3.0) == dtypes.default_float + assert dtypes.from_py([]) == dtypes.default_float + assert dtypes.from_py(()) == dtypes.default_float + assert dtypes.from_py([True]) == dtypes.bool + assert dtypes.from_py([True, 2]) == dtypes.default_int + assert dtypes.from_py([True, 3.0]) == dtypes.default_float + assert dtypes.from_py([2, 3.0]) == dtypes.default_float + assert dtypes.from_py([True, 2, 3.0]) == dtypes.default_float + with self.assertRaises(RuntimeError): dtypes.from_py(None) + with self.assertRaises(RuntimeError): dtypes.from_py([None]) + with self.assertRaises(RuntimeError): dtypes.from_py({}) + with self.assertRaises(RuntimeError): dtypes.from_py(set()) + class TestTypeSpec(unittest.TestCase): def setUp(self): self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float diff --git a/test/test_tensor.py b/test/test_tensor.py index 453112d7bb..a00a770a52 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -337,7 +337,7 @@ class TestTinygrad(unittest.TestCase): def test_no_bool(self): with self.assertRaises(TypeError): - if Tensor(["3"]): + if Tensor(3): print("hi") with self.assertRaises(TypeError): diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index d5b0d90216..20deb07102 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -48,8 +48,14 @@ class dtypes: def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) @staticmethod def from_np(x: type) -> DType: return DTYPES_DICT[np.dtype(x).name] - @staticmethod # NOTE: isinstance(True, int) is True in python - def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int + @staticmethod + def from_py(x) -> DType: + if isinstance(x, (list, tuple)): return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float + if isinstance(x, float): return dtypes.default_float + # NOTE: isinstance(True, int) is True in python, so check bool before int + if isinstance(x, bool): return dtypes.bool + if isinstance(x, int): return dtypes.default_int + raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}") @staticmethod def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val) @staticmethod diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 596c9adb8b..22c2135edd 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -7,7 +7,7 @@ from collections import defaultdict import numpy as np from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype -from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, argsort, getenv +from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY from tinygrad.lazy import LazyBuffer from tinygrad.multi import MultiLazyBuffer @@ -109,9 +109,7 @@ class Tensor: elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data) elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8)) elif isinstance(data, (list, tuple)): - if dtype is None: - if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool - else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float + if dtype is None: dtype = dtypes.from_py(data) if dtype == dtypes.bfloat16: data = Tensor(_fromcpu(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata else: data = _fromcpu(np.array(data, dtype.np)) elif isinstance(data, np.ndarray): @@ -2288,7 +2286,7 @@ class Tensor: # make y a Tensor assert isinstance(y, (float, int, bool, Node)), f"{type(y)=}, {y=}" if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype - else: y_dtype = dtypes.from_py(y) + elif not isinstance(y, Node): y_dtype = dtypes.from_py(y) if isinstance(y, Node): y = Tensor.from_node(y, device=self.device) else: y = Tensor(dtypes.as_const(y, y_dtype), self.device, y_dtype, requires_grad=False)