diff --git a/test/test_tensor.py b/test/test_tensor.py index 199fe3a6f2..4a1a9047e8 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1,9 +1,11 @@ import numpy as np import torch +import struct import unittest, copy from tinygrad.tensor import Tensor, Device from tinygrad.helpers import dtypes from extra.gradcheck import numerical_jacobian, jacobian, gradcheck +from extra.utils import temp x_init = np.random.randn(1,3).astype(np.float32) U_init = np.random.randn(3,3).astype(np.float32) @@ -240,5 +242,11 @@ class TestTinygrad(unittest.TestCase): x = copy.deepcopy(Tensor.ones((3,3,3))) np.testing.assert_allclose(x.numpy(), np.ones((3,3,3))) + def test_copy_from_disk(self): + t = Tensor.randn(30, device="CPU").to(f"disk:{temp('test_copy_from_disk')}") + a = t[10:20] + dev = a.to(Device.DEFAULT) + np.testing.assert_allclose(a.numpy(), dev.numpy()) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index d085093b13..8165f9a2dd 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -235,12 +235,14 @@ class LazyBuffer: def fromCPU(x: np.ndarray) -> LazyBuffer: return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x)) + def prepare_transfer(self): + self_casted = self.e(UnaryOps.CAST, arg=(dtypes.from_np(self.dtype.np), False)) if dtypes.from_np(self.dtype.np) != self.dtype else self + return self_casted.contiguous().realize().realized + def toCPU(self) -> np.ndarray: assert self.dtype.np, f"{self.dtype} is not supported in toCPU" - self_casted = self.e(UnaryOps.CAST, arg=(dtypes.from_np(self.dtype.np), False)) if dtypes.from_np(self.dtype.np) != self.dtype else self - realized = self_casted.contiguous().realize().realized assert all_int(self.shape), f"no toCPU if shape is symbolic, {self.shape=}" - return cast(RawBuffer, realized).toCPU().reshape(self.shape) + return cast(RawBuffer, self.prepare_transfer()).toCPU().reshape(self.shape) # *** elementwise ops *** @@ -410,7 +412,7 @@ def _realize_from(buffer: LazyBuffer) -> None: if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped): assert all_int(buffer.shape), "does not support symbolic shape" buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args()) - rawbuf.realized.readinto(cast(RawBufferMapped, buffer.realized)._buffer()) + rawbuf.prepare_transfer().readinto(cast(RawBufferMapped, buffer.realized)._buffer()) elif isinstance(rawbuf.realized, RawBufferTransfer) and issubclass(Device[buffer.device].buffer, RawBufferTransfer) and P2P >= 1: buffer.realized = cast(RawBufferTransfer, Device[buffer.device].buffer).transfer(rawbuf.realized, buffer.shape, buffer.dtype, **buffer._device_extra_args()) else: