Shape changing bitcast and assert bitcast in disk (#3973)

* Shape changing bitcast

* only support it on disk

* basic test

* more tests

* RuntimeError instead of assert

* create unique temp files

* move tests that use disk to test_disk_tensor

* linter

* remove assert on error messages

* that's RuntimeError now

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
uuuvn
2024-03-29 06:49:10 +02:00
committed by GitHub
parent 793ab0512e
commit 8a40d7d423
6 changed files with 56 additions and 11 deletions

View File

@@ -79,10 +79,18 @@ class LazyBuffer:
def cast(self, dtype:DType, bitcast:bool=False):
  """Return a lazy buffer with dtype `dtype`.

  A normal cast converts values; a bitcast reinterprets the underlying bytes.
  DISK buffers only support bitcasts. A bitcast between dtypes of different
  itemsize changes the last axis length (DISK only, concrete shapes only).
  """
  if self.dtype == dtype: return self
  # DISK holds raw bytes on disk, so value-converting casts are impossible there
  if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
  # TODO: applying this makes gpt2 slower
  if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
    return self.base.cast(dtype, bitcast)._view(self.st)
  new_shape = self.shape
  if bitcast and self.dtype.itemsize != dtype.itemsize:
    # shape-changing bitcast: rescale the last axis by the itemsize ratio
    if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
    if not all_int(new_shape): raise RuntimeError("shape changing bitcast with symbolic shape isn't supported yet")
    # same rule as torch: https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
    if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
    new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
  return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,))
def is_unrealized_const(self):
  """Whether this buffer's base is a CONST load that has not been realized yet."""
  base = self.base
  return (not base.realized) and base.op is LoadOps.CONST
def is_unrealized_contiguous_const(self):
  """Whether this buffer is its own base (no view applied) and an unrealized CONST load."""
  if not (self.base == self):
    return False
  return (not self.base.realized) and self.op is LoadOps.CONST

View File

@@ -14,7 +14,7 @@ inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
  """Open a safetensors file (or accept an already-open disk Tensor) and parse its header.

  Returns (byte_tensor, json_len, header_dict). The first 8 bytes hold the JSON
  header length as an int64; the header itself follows as JSON starting at byte 8.
  """
  t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
  # bitcast (not cast): disk tensors only support byte reinterpretation
  json_len = t[0:8].bitcast(dtypes.int64).item()
  return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
@@ -23,8 +23,8 @@ def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
for k,v in metadata.items():
if k == "__metadata__": continue
dtype = safe_dtypes[v['dtype']]
sz = (v['data_offsets'][1]-v['data_offsets'][0])//dtype.itemsize
ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].cast(dtype).reshape(v['shape'])
sz = (v['data_offsets'][1]-v['data_offsets'][0])
ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
return ret
def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
@@ -37,7 +37,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any
j += "\x20"*((8-len(j)%8)%8)
pathlib.Path(fn).unlink(missing_ok=True)
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
t[0:1].cast(dtypes.int64).assign([len(j)])
t[0:8].bitcast(dtypes.int64).assign([len(j)])
t[8:8+len(j)].assign(list(j.encode('utf-8')))
for k,v in safe_load(t).items(): v.assign(tensors[k])
@@ -83,7 +83,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
lens[storage[2]] = storage[4] * storage[1].itemsize
if storage[2] not in offsets: return None
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]

View File

@@ -53,7 +53,7 @@ class DiskRunner(JITRunner):
# TODO: there shouldn't actually be casts here, bitcasts should fold into the load
if ast.src[0].op == UnaryOps.CAST:
top_src = ast.src[0].src[0]
# TODO: assert that this is bitcast
assert ast.src[0].arg[1], "disk only supports bitcasts, not normal casts"
self.new_dtype = ast.src[0].arg[0]
else:
top_src = ast.src[0]

View File

@@ -1033,7 +1033,7 @@ class Tensor:
return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
def cast(self, dtype:DType) -> Tensor:
  """Convert the tensor's values to `dtype`; returns self unchanged if already that dtype."""
  if self.dtype == dtype:
    return self
  return mlops.Cast.apply(self, dtype=dtype)
def bitcast(self, dtype:DType) -> Tensor:
  """Reinterpret the tensor's underlying bytes as `dtype` without converting values.

  Raises RuntimeError for tensors requiring grad, since a byte reinterpretation
  is not differentiable. Mismatched itemsizes are allowed here; the lazybuffer
  layer enforces where shape-changing bitcasts are supported (DISK only).
  """
  if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
  return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
def float(self) -> Tensor:
  """Shorthand for casting to float32."""
  return self.cast(dtypes.float32)
def half(self) -> Tensor:
  """Shorthand for casting to float16."""
  return self.cast(dtypes.float16)