support list, tuple input in dtypes.from_py (#4945)

* support list, tuple input in dtypes.from_py

and used it to infer dtype from python list and tuple in Tensor constructor.

* fix tests
This commit is contained in:
chenyu
2024-06-13 13:38:06 -04:00
committed by GitHub
parent 7aecea4f56
commit 287d3c3b84
4 changed files with 28 additions and 8 deletions

View File

@@ -329,6 +329,22 @@ class TestHelpers(unittest.TestCase):
def test_scalar(self, dtype, amt):
assert dtype.vec(amt).scalar() == dtype
def test_from_py(self):
assert dtypes.from_py(True) == dtypes.bool
assert dtypes.from_py(2) == dtypes.default_int
assert dtypes.from_py(3.0) == dtypes.default_float
assert dtypes.from_py([]) == dtypes.default_float
assert dtypes.from_py(()) == dtypes.default_float
assert dtypes.from_py([True]) == dtypes.bool
assert dtypes.from_py([True, 2]) == dtypes.default_int
assert dtypes.from_py([True, 3.0]) == dtypes.default_float
assert dtypes.from_py([2, 3.0]) == dtypes.default_float
assert dtypes.from_py([True, 2, 3.0]) == dtypes.default_float
with self.assertRaises(RuntimeError): dtypes.from_py(None)
with self.assertRaises(RuntimeError): dtypes.from_py([None])
with self.assertRaises(RuntimeError): dtypes.from_py({})
with self.assertRaises(RuntimeError): dtypes.from_py(set())
class TestTypeSpec(unittest.TestCase):
def setUp(self):
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float

View File

@@ -337,7 +337,7 @@ class TestTinygrad(unittest.TestCase):
def test_no_bool(self):
with self.assertRaises(TypeError):
if Tensor(["3"]):
if Tensor(3):
print("hi")
with self.assertRaises(TypeError):

View File

@@ -48,8 +48,14 @@ class dtypes:
def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def from_np(x: type) -> DType: return DTYPES_DICT[np.dtype(x).name]
@staticmethod # NOTE: isinstance(True, int) is True in python
def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int
@staticmethod
def from_py(x) -> DType:
if isinstance(x, (list, tuple)): return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
if isinstance(x, float): return dtypes.default_float
# NOTE: isinstance(True, int) is True in python, so check bool before int
if isinstance(x, bool): return dtypes.bool
if isinstance(x, int): return dtypes.default_int
raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
@staticmethod
def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
@staticmethod

View File

@@ -7,7 +7,7 @@ from collections import defaultdict
import numpy as np
from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, argsort, getenv
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv
from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
from tinygrad.lazy import LazyBuffer
from tinygrad.multi import MultiLazyBuffer
@@ -109,9 +109,7 @@ class Tensor:
elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
elif isinstance(data, (list, tuple)):
if dtype is None:
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
if dtype is None: dtype = dtypes.from_py(data)
if dtype == dtypes.bfloat16: data = Tensor(_fromcpu(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
else: data = _fromcpu(np.array(data, dtype.np))
elif isinstance(data, np.ndarray):
@@ -2288,7 +2286,7 @@ class Tensor:
# make y a Tensor
assert isinstance(y, (float, int, bool, Node)), f"{type(y)=}, {y=}"
if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
else: y_dtype = dtypes.from_py(y)
elif not isinstance(y, Node): y_dtype = dtypes.from_py(y)
if isinstance(y, Node): y = Tensor.from_node(y, device=self.device)
else: y = Tensor(dtypes.as_const(y, y_dtype), self.device, y_dtype, requires_grad=False)