mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
call dtypes.as_const in Tensor(list) (#11840)
This commit is contained in:
@@ -415,6 +415,21 @@ class TestTinygrad(unittest.TestCase):
|
||||
data = _generate_data(depth)
|
||||
np.testing.assert_allclose(Tensor(data).numpy(), np.array(data))
|
||||
|
||||
def test_tensor_list_implicit_cast(self):
    """Tensor(list) with an explicit dtype casts the python values the same way torch does.

    Covers bool, int, and float source data; for float source data the uint8
    comparison is skipped because torch/jax raise OverflowError there.
    """
    casts = ((dtypes.int, torch.int), (dtypes.uint8, torch.uint8), (dtypes.float, torch.float))
    # bool and int inputs: every target dtype must round-trip like torch
    for data in ([True, False], [-1, 0, 1, 2, 3]):
        for tiny_dtype, torch_dtype in casts:
            np.testing.assert_equal(Tensor(data, dtype=tiny_dtype).numpy(), torch.tensor(data, dtype=torch_dtype).numpy())
    # float inputs: int and float targets only
    data = [-3.5, -2.5, -1.5, 0, 1.5, 2.5, 3.5]
    np.testing.assert_equal(Tensor(data, dtype=dtypes.int).numpy(), torch.tensor(data, dtype=torch.int).numpy())
    # NOTE: torch and jax raise OverflowError: Python integer -3 out of bounds for uint8
    # np.testing.assert_equal(Tensor(data, dtype=dtypes.uint8).numpy(), torch.tensor(data, dtype=torch.uint8).numpy())
    np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy())
|
||||
|
||||
def test_tensor_list_special_values(self):
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
data = [math.nan, -math.inf, 65504, 65519, 65519.999, 65520, 65520.1]
|
||||
|
||||
@@ -108,7 +108,6 @@ class dtypes:
|
||||
if isinstance(val, tuple):
|
||||
assert len(val) == dtype.count, f"mismatch {val} {dtype}"
|
||||
return tuple(dtypes.as_const(x, dtype) for x in val)
|
||||
# TODO: should truncate here
|
||||
return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
|
||||
@staticmethod
|
||||
@functools.cache
|
||||
|
||||
@@ -68,7 +68,7 @@ def _frompy(x:list|tuple|bytes, dtype:DType) -> UOp:
|
||||
ret = UOp.new_buffer("PYTHON", prod(shape:=get_shape(x)), dtype).reshape(shape)
|
||||
assert dtype.fmt is not None, f"{dtype=} has None fmt"
|
||||
truncate_function = truncate[dtype]
|
||||
data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
|
||||
data = struct.pack(f"{ret.size}{dtype.fmt}", *[truncate_function(dtypes.as_const(xi, dtype)) for xi in fully_flatten(x)])
|
||||
# fake realize
|
||||
ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user