create Tensor from bytes without numpy (#4964)

This commit is contained in:
chenyu
2024-06-14 13:37:27 -04:00
committed by GitHub
parent 5eee974b2a
commit dae1c8abe2
2 changed files with 15 additions and 5 deletions

View File

@@ -300,6 +300,13 @@ class TestTinygrad(unittest.TestCase):
data = _generate_data(depth)
np.testing.assert_allclose(Tensor(data).numpy(), np.array(data))
def test_tensor_bytes(self):
data = b"abc123"
t = Tensor(data)
assert t.dtype == dtypes.uint8
assert t.shape == (6,)
np.testing.assert_equal(t.numpy(), list(data))
def test_tensor_copy(self):
x = copy.deepcopy(Tensor.ones((3,3,3)))
np.testing.assert_allclose(x.numpy(), np.ones((3,3,3)))

View File

@@ -50,11 +50,14 @@ def _fromcpu(x: np.ndarray) -> LazyBuffer:
del ret.srcs
return ret
def _frompy(x:Union[List, Tuple], dtype:DType) -> LazyBuffer:
ret = LazyBuffer.loadop(LoadOps.EMPTY, get_shape(x), dtype, "PYTHON")
def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
if isinstance(x, bytes): ret, data = LazyBuffer.loadop(LoadOps.EMPTY, (len(x),), dtype, "PYTHON"), x
else:
ret = LazyBuffer.loadop(LoadOps.EMPTY, get_shape(x), dtype, "PYTHON")
assert dtype.fmt is not None, f"{dtype=} has None fmt"
data = struct.pack(f"@{ret.size}{dtype.fmt}", *fully_flatten(x))
# fake realize
assert dtype.fmt is not None, f"{dtype=} has None fmt"
ret.buffer.allocate(memoryview(struct.pack(f"@{ret.size}{dtype.fmt}", *fully_flatten(x))))
ret.buffer.allocate(memoryview(data))
del ret.srcs
return ret
@@ -115,7 +118,7 @@ class Tensor:
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, get_args(ConstType)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8)
elif isinstance(data, (list, tuple)):
if dtype is None: dtype = dtypes.from_py(data)
if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata