call dtypes.as_const in Tensor(list) (#11840)

This commit is contained in:
chenyu
2025-08-25 22:08:26 -04:00
committed by GitHub
parent 215818379b
commit 337e979a59
3 changed files with 16 additions and 2 deletions

View File

@@ -415,6 +415,21 @@ class TestTinygrad(unittest.TestCase):
data = _generate_data(depth)
np.testing.assert_allclose(Tensor(data).numpy(), np.array(data))
def test_tensor_list_implicit_cast(self):
data = [True, False]
np.testing.assert_equal(Tensor(data, dtype=dtypes.int).numpy(), torch.tensor(data, dtype=torch.int).numpy())
np.testing.assert_equal(Tensor(data, dtype=dtypes.uint8).numpy(), torch.tensor(data, dtype=torch.uint8).numpy())
np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy())
data = [-1, 0, 1, 2, 3]
np.testing.assert_equal(Tensor(data, dtype=dtypes.int).numpy(), torch.tensor(data, dtype=torch.int).numpy())
np.testing.assert_equal(Tensor(data, dtype=dtypes.uint8).numpy(), torch.tensor(data, dtype=torch.uint8).numpy())
np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy())
data = [-3.5, -2.5, -1.5, 0, 1.5, 2.5, 3.5]
np.testing.assert_equal(Tensor(data, dtype=dtypes.int).numpy(), torch.tensor(data, dtype=torch.int).numpy())
# NOTE: torch and jax raise OverflowError: Python integer -3 out of bounds for uint8
# np.testing.assert_equal(Tensor(data, dtype=dtypes.uint8).numpy(), torch.tensor(data, dtype=torch.uint8).numpy())
np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy())
def test_tensor_list_special_values(self):
if is_dtype_supported(dtypes.float16):
data = [math.nan, -math.inf, 65504, 65519, 65519.999, 65520, 65520.1]

View File

@@ -108,7 +108,6 @@ class dtypes:
if isinstance(val, tuple):
assert len(val) == dtype.count, f"mismatch {val} {dtype}"
return tuple(dtypes.as_const(x, dtype) for x in val)
# TODO: should truncate here
return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
@staticmethod
@functools.cache

View File

@@ -68,7 +68,7 @@ def _frompy(x:list|tuple|bytes, dtype:DType) -> UOp:
ret = UOp.new_buffer("PYTHON", prod(shape:=get_shape(x)), dtype).reshape(shape)
assert dtype.fmt is not None, f"{dtype=} has None fmt"
truncate_function = truncate[dtype]
data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
data = struct.pack(f"{ret.size}{dtype.fmt}", *[truncate_function(dtypes.as_const(xi, dtype)) for xi in fully_flatten(x)])
# fake realize
ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
return ret