diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index 92bf5f14bf..5298cca3f1 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -205,7 +205,6 @@ class TestDiskTensor(unittest.TestCase): helper_test_disk_tensor("dt5", [1,2,3,4,5], lambda x: x.reshape((1,5))) helper_test_disk_tensor("dt6", [1,2,3,4], lambda x: x.reshape((2,2))) - @unittest.expectedFailure def test_assign_to_different_dtype(self): # NOTE: this is similar to Y_train in fetch_cifar t = Tensor.empty(10, device=f'disk:{temp("dt7")}', dtype=dtypes.int64) diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py index 281a2f2716..71179d5a01 100644 --- a/tinygrad/runtime/ops_disk.py +++ b/tinygrad/runtime/ops_disk.py @@ -64,12 +64,11 @@ class DiskRunner(JITRunner): assert view.mask is None, "view cannot have a mask" assert strides_for_shape(view.shape) == view.strides, "disk tensors don't support strides" self.new_size = prod(view.shape) - self.new_offset = view.offset + self.new_offset = view.offset * top_src.arg.dtype.itemsize def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False, jit=False): assert len(rawbufs) == 2 src = rawbufs[1]._buf - # TODO: src.dtype.itemsize or self.new_dtype.itemsize? - rawbufs[0]._buf = DiskBuffer(src.ud, self.new_size, self.new_dtype, offset=src.offset+self.new_offset*src.dtype.itemsize) + rawbufs[0]._buf = DiskBuffer(src.ud, self.new_size, self.new_dtype, offset=src.offset+self.new_offset) class DiskDevice(Compiled): def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), None, None)