diff --git a/test/test_dtype.py b/test/test_dtype.py index 9254b4ac07..6311b4046e 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -249,7 +249,7 @@ class TestUint8Dtype(TestDType): @unittest.skipIf(Device.DEFAULT == "WEBGL", "No bitcast on WebGL") class TestBitCast(unittest.TestCase): def test_shape_change_bitcast(self): - with self.assertRaises(AssertionError): + with self.assertRaises(RuntimeError): _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000]) def test_bitcast_float_to_int32(self): diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index 1929296a57..643a7593f2 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -1,6 +1,7 @@ -import pathlib, unittest +import pathlib, tempfile, unittest import numpy as np from tinygrad import Tensor, Device, dtypes +from tinygrad.dtype import DType from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load from tinygrad.helpers import Timing, fetch, temp from test.helpers import is_dtype_supported @@ -36,14 +37,50 @@ test_fn = pathlib.Path(__file__).parents[2] / "weights/LLaMA/7B/consolidated.00. 
#test_size = test_fn.stat().st_size test_size = 1024*1024*1024*2 +def _test_bitcasted(t: Tensor, dt: DType, expected): + np.testing.assert_allclose(t.bitcast(dt).numpy(), expected) + # sudo su -c 'sync; echo 1 > /proc/sys/vm/drop_caches' && python3 test/unit/test_disk_tensor.py TestRawDiskBuffer.test_readinto_read_speed -@unittest.skipIf(not test_fn.exists(), "download LLaMA weights for read in speed tests") class TestRawDiskBuffer(unittest.TestCase): + @unittest.skipIf(not test_fn.exists(), "download LLaMA weights for read in speed tests") def test_readinto_read_speed(self): tst = np.empty(test_size, np.uint8) with open(test_fn, "rb") as f: with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"): f.readinto(tst) + def test_bitcasts_on_disk(self): + tmp = tempfile.mktemp() + # ground truth = https://evanw.github.io/float-toy/ + t = Tensor.empty((128, 128), dtype=dtypes.uint8, device=f"disk:{tmp}") # uint8 + # all zeroes + _test_bitcasted(t, dtypes.float16, 0.0) + _test_bitcasted(t, dtypes.uint16, 0) + _test_bitcasted(t, dtypes.float32, 0.0) + _test_bitcasted(t, dtypes.uint32, 0) + # pi in float16 stored via int16 + t.bitcast(dtypes.uint16).assign(Tensor.full((128, 64), 0x4248, dtype=dtypes.uint16)).realize() + _test_bitcasted(t, dtypes.float16, 3.141) + _test_bitcasted(t, dtypes.float32, 50.064727) + _test_bitcasted(t, dtypes.uint16, 0x4248) + _test_bitcasted(t, dtypes.uint32, 0x42484248) + # pi in float32 stored via float32 + t.bitcast(dtypes.float32).assign(Tensor.full((128, 32), 3.1415927, dtype=dtypes.float32)).realize() + _test_bitcasted(t, dtypes.float32, 3.1415927) + _test_bitcasted(t, dtypes.uint32, 0x40490FDB) + # doesn't support normal cast + with self.assertRaises(RuntimeError): + Tensor.empty((4,), dtype=dtypes.int16, device=f"disk:{tmp}").cast(dtypes.float16) + + # These two should be moved to test_dtype.py:test_shape_change_bitcast after bitcast works on non-disk + with self.assertRaises(RuntimeError): + # should fail because 3 int8 is 3 
bytes but float16 is two and 3 isn't a multiple of 2 + Tensor.empty((3,), dtype=dtypes.int8, device=f"DISK:{tmp}").bitcast(dtypes.float16) + + with self.assertRaises(RuntimeError): + # should fail because backprop through bitcast is undefined + Tensor.empty((4,), dtype=dtypes.int8, requires_grad=True, device=f"DISK:{tmp}").bitcast(dtypes.float16) + + pathlib.Path(tmp).unlink() @unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu doesn't support uint8 datatype") class TestSafetensors(unittest.TestCase): diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 31c2655be2..f76971a1a3 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -79,10 +79,18 @@ class LazyBuffer: def cast(self, dtype:DType, bitcast:bool=False): if self.dtype == dtype: return self + if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)") # TODO: applying this makes gpt2 slower if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base: return self.base.cast(dtype, bitcast)._view(self.st) - return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,)) + new_shape = self.shape + if bitcast and self.dtype.itemsize != dtype.itemsize: + if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now") + if not all_int(new_shape): raise RuntimeError("shape changing bitcast with symbolic shape isn't supported yet") + # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html + if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast") + new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,) + return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,)) def is_unrealized_const(self): return not self.base.realized and self.base.op is 
LoadOps.CONST def is_unrealized_contiguous_const(self): return self.base == self and not self.base.realized and self.op is LoadOps.CONST diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index eb98b586f2..38a9713996 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -14,7 +14,7 @@ inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()} def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]: t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}") - json_len = t[0:1].cast(dtypes.int64).item() + json_len = t[0:8].bitcast(dtypes.int64).item() return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes()) def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]: @@ -23,8 +23,8 @@ def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]: for k,v in metadata.items(): if k == "__metadata__": continue dtype = safe_dtypes[v['dtype']] - sz = (v['data_offsets'][1]-v['data_offsets'][0])//dtype.itemsize - ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].cast(dtype).reshape(v['shape']) + sz = (v['data_offsets'][1]-v['data_offsets'][0]) + ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape']) return ret def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None): @@ -37,7 +37,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any j += "\x20"*((8-len(j)%8)%8) pathlib.Path(fn).unlink(missing_ok=True) t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}") - t[0:1].cast(dtypes.int64).assign([len(j)]) + t[0:8].bitcast(dtypes.int64).assign([len(j)]) t[8:8+len(j)].assign(list(j.encode('utf-8'))) for k,v in safe_load(t).items(): v.assign(tensors[k]) @@ -83,7 +83,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]: lens[storage[2]] = storage[4] * storage[1].itemsize if storage[2] not in offsets: return None byte_offset = 
offsets[storage[2]]+storage_offset*storage[1].itemsize - ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1]) + ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1]) # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1] diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py index 3903666a3a..9691aba5a3 100644 --- a/tinygrad/runtime/ops_disk.py +++ b/tinygrad/runtime/ops_disk.py @@ -53,7 +53,7 @@ class DiskRunner(JITRunner): # TODO: there shouldn't actually be casts here, bitcasts should fold into the load if ast.src[0].op == UnaryOps.CAST: top_src = ast.src[0].src[0] - # TODO: assert that this is bitcast + assert ast.src[0].arg[1], "disk only supports bitcasts, not normal casts" self.new_dtype = ast.src[0].arg[0] else: top_src = ast.src[0] diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 8c1e71f95f..e0844fec63 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1033,7 +1033,7 @@ class Tensor: return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype) def cast(self, dtype:DType) -> Tensor: return self if self.dtype == dtype else mlops.Cast.apply(self, dtype=dtype) def bitcast(self, dtype:DType) -> Tensor: - assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes" + if self.requires_grad: raise RuntimeError("can't backprop through bitcast") return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self def float(self) -> Tensor: return self.cast(dtypes.float32) def half(self) -> Tensor: return self.cast(dtypes.float16)