mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
call dtypes.as_const in Tensor(list) (#11840)
This commit is contained in:
@@ -415,6 +415,21 @@ class TestTinygrad(unittest.TestCase):
|
||||
data = _generate_data(depth)
|
||||
np.testing.assert_allclose(Tensor(data).numpy(), np.array(data))
|
||||
|
||||
def test_tensor_list_implicit_cast(self):
|
||||
data = [True, False]
|
||||
np.testing.assert_equal(Tensor(data, dtype=dtypes.int).numpy(), torch.tensor(data, dtype=torch.int).numpy())
|
||||
np.testing.assert_equal(Tensor(data, dtype=dtypes.uint8).numpy(), torch.tensor(data, dtype=torch.uint8).numpy())
|
||||
np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy())
|
||||
data = [-1, 0, 1, 2, 3]
|
||||
np.testing.assert_equal(Tensor(data, dtype=dtypes.int).numpy(), torch.tensor(data, dtype=torch.int).numpy())
|
||||
np.testing.assert_equal(Tensor(data, dtype=dtypes.uint8).numpy(), torch.tensor(data, dtype=torch.uint8).numpy())
|
||||
np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy())
|
||||
data = [-3.5, -2.5, -1.5, 0, 1.5, 2.5, 3.5]
|
||||
np.testing.assert_equal(Tensor(data, dtype=dtypes.int).numpy(), torch.tensor(data, dtype=torch.int).numpy())
|
||||
# NOTE: torch and jax raise OverflowError: Python integer -3 out of bounds for uint8
|
||||
# np.testing.assert_equal(Tensor(data, dtype=dtypes.uint8).numpy(), torch.tensor(data, dtype=torch.uint8).numpy())
|
||||
np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy())
|
||||
|
||||
def test_tensor_list_special_values(self):
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
data = [math.nan, -math.inf, 65504, 65519, 65519.999, 65520, 65520.1]
|
||||
|
||||
Reference in New Issue
Block a user