fix disktensor offset issue (#3532)

This commit is contained in:
George Hotz
2024-02-28 17:22:17 -08:00
committed by GitHub
parent 0b1fc5888a
commit 48918fa75a
2 changed files with 2 additions and 4 deletions

View File

@@ -64,12 +64,11 @@ class DiskRunner(JITRunner):
assert view.mask is None, "view cannot have a mask"
assert strides_for_shape(view.shape) == view.strides, "disk tensors don't support strides"
self.new_size = prod(view.shape)
self.new_offset = view.offset
self.new_offset = view.offset * top_src.arg.dtype.itemsize
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False, jit=False):
assert len(rawbufs) == 2
src = rawbufs[1]._buf
# TODO: src.dtype.itemsize or self.new_dtype.itemsize?
rawbufs[0]._buf = DiskBuffer(src.ud, self.new_size, self.new_dtype, offset=src.offset+self.new_offset*src.dtype.itemsize)
rawbufs[0]._buf = DiskBuffer(src.ud, self.new_size, self.new_dtype, offset=src.offset+self.new_offset)
class DiskDevice(Compiled):
def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), None, None)