call dtypes.as_const in Tensor(list) (#11840)

This commit is contained in:
chenyu
2025-08-25 22:08:26 -04:00
committed by GitHub
parent 215818379b
commit 337e979a59
3 changed files with 16 additions and 2 deletions

View File

@@ -415,6 +415,21 @@ class TestTinygrad(unittest.TestCase):
data = _generate_data(depth)
np.testing.assert_allclose(Tensor(data).numpy(), np.array(data))
def test_tensor_list_implicit_cast(self):
data = [True, False]
np.testing.assert_equal(Tensor(data, dtype=dtypes.int).numpy(), torch.tensor(data, dtype=torch.int).numpy())
np.testing.assert_equal(Tensor(data, dtype=dtypes.uint8).numpy(), torch.tensor(data, dtype=torch.uint8).numpy())
np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy())
data = [-1, 0, 1, 2, 3]
np.testing.assert_equal(Tensor(data, dtype=dtypes.int).numpy(), torch.tensor(data, dtype=torch.int).numpy())
np.testing.assert_equal(Tensor(data, dtype=dtypes.uint8).numpy(), torch.tensor(data, dtype=torch.uint8).numpy())
np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy())
data = [-3.5, -2.5, -1.5, 0, 1.5, 2.5, 3.5]
np.testing.assert_equal(Tensor(data, dtype=dtypes.int).numpy(), torch.tensor(data, dtype=torch.int).numpy())
# NOTE: torch and jax raise OverflowError: Python integer -3 out of bounds for uint8
# np.testing.assert_equal(Tensor(data, dtype=dtypes.uint8).numpy(), torch.tensor(data, dtype=torch.uint8).numpy())
np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy())
def test_tensor_list_special_values(self):
if is_dtype_supported(dtypes.float16):
data = [math.nan, -math.inf, 65504, 65519, 65519.999, 65520, 65520.1]