mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-11 15:15:13 -05:00
Shape changing bitcast and assert bitcast in disk (#3973)
* Shape changing bitcast
* only support it on disk
* basic test
* more tests
* RuntimeError instead of assert
* create unique temp files
* move tests that use disk to test_disk_tensor
* linter
* remove assert on error messages
* that's RuntimeError now

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -79,10 +79,18 @@ class LazyBuffer:
|
||||
|
||||
def cast(self, dtype:DType, bitcast:bool=False):
  """Return a lazy buffer holding this buffer cast (or bitcast) to `dtype`.

  If `dtype` already matches, `self` is returned unchanged. With `bitcast=True`
  and a different itemsize, the last dimension is rescaled so the total number
  of bytes stays the same (the rule torch.Tensor.view(dtype) uses); this is
  only permitted when the device name starts with "DISK" and the shape is made
  of plain ints.

  Raises:
    RuntimeError: for a non-bitcast cast of a DISK buffer, a shape-changing
      bitcast on a non-DISK device, a shape-changing bitcast of a symbolic
      shape, or a last dimension whose byte size is not a multiple of the new
      itemsize.
  """
  if self.dtype == dtype: return self
  if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
  new_shape = self.shape
  if bitcast and self.dtype.itemsize != dtype.itemsize:
    if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
    if not all_int(new_shape): raise RuntimeError("shape changing bitcast with symbolic shape isn't supported yet")
    # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
    if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
    new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
  elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
    # Only safe when the shape is unchanged: the view self.st is re-applied on top of the cast base,
    # so a shape-changing bitcast must not take this shortcut.
    # TODO: applying this makes gpt2 slower
    return self.base.cast(dtype, bitcast)._view(self.st)
  return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,))
|
||||
|
||||
def is_unrealized_const(self):
  """True when the base buffer is not realized and its op is LoadOps.CONST."""
  if self.base.realized: return False
  return self.base.op is LoadOps.CONST
|
||||
def is_unrealized_contiguous_const(self):
  """True when this buffer is its own base, is not realized, and its op is LoadOps.CONST."""
  is_own_base = self.base == self
  return is_own_base and not self.base.realized and self.op is LoadOps.CONST
|
||||
|
||||
@@ -14,7 +14,7 @@ inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
|
||||
|
||||
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
  """Read the JSON header of a file laid out as [8-byte length][JSON][data].

  Args:
    fn: an existing Tensor over the raw bytes, or a path that is opened as a
      uint8 tensor on the `disk:` device sized to the file.

  Returns:
    (t, json_len, metadata): the byte tensor, the header length read from the
    first 8 bytes as int64, and the parsed JSON from the next `json_len` bytes.
  """
  t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
  # The length field is 8 bytes wide; it must be bitcast (not cast) because disk buffers only support bitcast.
  json_len = t[0:8].bitcast(dtypes.int64).item()
  return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())
|
||||
|
||||
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
|
||||
@@ -23,8 +23,8 @@ def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
|
||||
for k,v in metadata.items():
|
||||
if k == "__metadata__": continue
|
||||
dtype = safe_dtypes[v['dtype']]
|
||||
sz = (v['data_offsets'][1]-v['data_offsets'][0])//dtype.itemsize
|
||||
ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].cast(dtype).reshape(v['shape'])
|
||||
sz = (v['data_offsets'][1]-v['data_offsets'][0])
|
||||
ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
|
||||
return ret
|
||||
|
||||
def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
|
||||
@@ -37,7 +37,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any
|
||||
j += "\x20"*((8-len(j)%8)%8)
|
||||
pathlib.Path(fn).unlink(missing_ok=True)
|
||||
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
|
||||
t[0:1].cast(dtypes.int64).assign([len(j)])
|
||||
t[0:8].bitcast(dtypes.int64).assign([len(j)])
|
||||
t[8:8+len(j)].assign(list(j.encode('utf-8')))
|
||||
for k,v in safe_load(t).items(): v.assign(tensors[k])
|
||||
|
||||
@@ -83,7 +83,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
|
||||
lens[storage[2]] = storage[4] * storage[1].itemsize
|
||||
if storage[2] not in offsets: return None
|
||||
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
|
||||
ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
|
||||
ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])
|
||||
|
||||
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
|
||||
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
|
||||
|
||||
@@ -53,7 +53,7 @@ class DiskRunner(JITRunner):
|
||||
# TODO: there shouldn't actually be casts here, bitcasts should fold into the load
|
||||
if ast.src[0].op == UnaryOps.CAST:
|
||||
top_src = ast.src[0].src[0]
|
||||
# TODO: assert that this is bitcast
|
||||
assert ast.src[0].arg[1], "disk only supports bitcasts, not normal casts"
|
||||
self.new_dtype = ast.src[0].arg[0]
|
||||
else:
|
||||
top_src = ast.src[0]
|
||||
|
||||
@@ -1033,7 +1033,7 @@ class Tensor:
|
||||
return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
|
||||
def cast(self, dtype:DType) -> Tensor:
  """Return this tensor as `dtype`; `self` is returned unchanged when the dtype already matches."""
  if self.dtype == dtype: return self
  return mlops.Cast.apply(self, dtype=dtype)
|
||||
def bitcast(self, dtype:DType) -> Tensor:
  """Return this tensor bitcast to `dtype` via mlops.Cast with bitcast=True.

  Itemsizes are deliberately not checked here, so callers may bitcast between
  dtypes of different width (e.g. a uint8 slice to int64).

  Raises:
    RuntimeError: if the tensor requires grad.
  """
  if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
  return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
|
||||
def float(self) -> Tensor:
  """Shorthand for self.cast(dtypes.float32)."""
  return self.cast(dtypes.float32)
|
||||
def half(self) -> Tensor:
  """Shorthand for self.cast(dtypes.float16)."""
  return self.cast(dtypes.float16)
|
||||
|
||||
Reference in New Issue
Block a user