mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Shape changing bitcast and assert bitcast in disk (#3973)
* Shape changing bitcast * only support it on disk * basic test * more tests * RuntimeError instead of assert * create unique temp files * move tests that use disk to test_disk_tensor * linter * remove assert on error messages * that's RuntimeError now --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -249,7 +249,7 @@ class TestUint8Dtype(TestDType):
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGL", "No bitcast on WebGL")
|
||||
class TestBitCast(unittest.TestCase):
|
||||
def test_shape_change_bitcast(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(RuntimeError):
|
||||
_test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])
|
||||
|
||||
def test_bitcast_float_to_int32(self):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import pathlib, unittest
|
||||
import pathlib, tempfile, unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
from tinygrad.helpers import Timing, fetch, temp
|
||||
from test.helpers import is_dtype_supported
|
||||
@@ -36,14 +37,50 @@ test_fn = pathlib.Path(__file__).parents[2] / "weights/LLaMA/7B/consolidated.00.
|
||||
#test_size = test_fn.stat().st_size
|
||||
test_size = 1024*1024*1024*2
|
||||
|
||||
def _test_bitcasted(t: Tensor, dt: DType, expected):
|
||||
np.testing.assert_allclose(t.bitcast(dt).numpy(), expected)
|
||||
|
||||
# sudo su -c 'sync; echo 1 > /proc/sys/vm/drop_caches' && python3 test/unit/test_disk_tensor.py TestRawDiskBuffer.test_readinto_read_speed
|
||||
@unittest.skipIf(not test_fn.exists(), "download LLaMA weights for read in speed tests")
|
||||
class TestRawDiskBuffer(unittest.TestCase):
|
||||
@unittest.skipIf(not test_fn.exists(), "download LLaMA weights for read in speed tests")
|
||||
def test_readinto_read_speed(self):
|
||||
tst = np.empty(test_size, np.uint8)
|
||||
with open(test_fn, "rb") as f:
|
||||
with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
|
||||
f.readinto(tst)
|
||||
def test_bitcasts_on_disk(self):
|
||||
tmp = tempfile.mktemp()
|
||||
# ground truth = https://evanw.github.io/float-toy/
|
||||
t = Tensor.empty((128, 128), dtype=dtypes.uint8, device=f"disk:{tmp}") # uint8
|
||||
# all zeroes
|
||||
_test_bitcasted(t, dtypes.float16, 0.0)
|
||||
_test_bitcasted(t, dtypes.uint16, 0)
|
||||
_test_bitcasted(t, dtypes.float32, 0.0)
|
||||
_test_bitcasted(t, dtypes.uint32, 0)
|
||||
# pi in float16 stored via int16
|
||||
t.bitcast(dtypes.uint16).assign(Tensor.full((128, 64), 0x4248, dtype=dtypes.uint16)).realize()
|
||||
_test_bitcasted(t, dtypes.float16, 3.141)
|
||||
_test_bitcasted(t, dtypes.float32, 50.064727)
|
||||
_test_bitcasted(t, dtypes.uint16, 0x4248)
|
||||
_test_bitcasted(t, dtypes.uint32, 0x42484248)
|
||||
# pi in float32 stored via float32
|
||||
t.bitcast(dtypes.float32).assign(Tensor.full((128, 32), 3.1415927, dtype=dtypes.float32)).realize()
|
||||
_test_bitcasted(t, dtypes.float32, 3.1415927)
|
||||
_test_bitcasted(t, dtypes.uint32, 0x40490FDB)
|
||||
# doesn't suport normal cast
|
||||
with self.assertRaises(RuntimeError):
|
||||
Tensor.empty((4,), dtype=dtypes.int16, device=f"disk:{tmp}").cast(dtypes.float16)
|
||||
|
||||
# Those two should be moved to test_dtype.py:test_shape_change_bitcast after bitcast works on non-disk
|
||||
with self.assertRaises(RuntimeError):
|
||||
# should fail because 3 int8 is 3 bytes but float16 is two and 3 isn't a multiple of 2
|
||||
Tensor.empty((3,), dtype=dtypes.int8, device=f"DISK:{tmp}").bitcast(dtypes.float16)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
# should fail because backprop through bitcast is undefined
|
||||
Tensor.empty((4,), dtype=dtypes.int8, requires_grad=True, device=f"DISK:{tmp}").bitcast(dtypes.float16)
|
||||
|
||||
pathlib.Path(tmp).unlink()
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu doesn't support uint8 datatype")
|
||||
class TestSafetensors(unittest.TestCase):
|
||||
|
||||
@@ -79,10 +79,18 @@ class LazyBuffer:
|
||||
|
||||
def cast(self, dtype:DType, bitcast:bool=False):
|
||||
if self.dtype == dtype: return self
|
||||
if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
|
||||
# TODO: applying this makes gpt2 slower
|
||||
if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
|
||||
return self.base.cast(dtype, bitcast)._view(self.st)
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,))
|
||||
new_shape = self.shape
|
||||
if bitcast and self.dtype.itemsize != dtype.itemsize:
|
||||
if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
|
||||
if not all_int(new_shape): raise RuntimeError("shape changing bitcast with symbolic shape isn't supported yet")
|
||||
# https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
|
||||
if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
|
||||
new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,))
|
||||
|
||||
def is_unrealized_const(self): return not self.base.realized and self.base.op is LoadOps.CONST
|
||||
def is_unrealized_contiguous_const(self): return self.base == self and not self.base.realized and self.op is LoadOps.CONST
|
||||
|
||||
@@ -14,7 +14,7 @@ inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
|
||||
|
||||
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
|
||||
t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
|
||||
json_len = t[0:1].cast(dtypes.int64).item()
|
||||
json_len = t[0:8].bitcast(dtypes.int64).item()
|
||||
return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())
|
||||
|
||||
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
|
||||
@@ -23,8 +23,8 @@ def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
|
||||
for k,v in metadata.items():
|
||||
if k == "__metadata__": continue
|
||||
dtype = safe_dtypes[v['dtype']]
|
||||
sz = (v['data_offsets'][1]-v['data_offsets'][0])//dtype.itemsize
|
||||
ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].cast(dtype).reshape(v['shape'])
|
||||
sz = (v['data_offsets'][1]-v['data_offsets'][0])
|
||||
ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
|
||||
return ret
|
||||
|
||||
def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
|
||||
@@ -37,7 +37,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any
|
||||
j += "\x20"*((8-len(j)%8)%8)
|
||||
pathlib.Path(fn).unlink(missing_ok=True)
|
||||
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
|
||||
t[0:1].cast(dtypes.int64).assign([len(j)])
|
||||
t[0:8].bitcast(dtypes.int64).assign([len(j)])
|
||||
t[8:8+len(j)].assign(list(j.encode('utf-8')))
|
||||
for k,v in safe_load(t).items(): v.assign(tensors[k])
|
||||
|
||||
@@ -83,7 +83,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
|
||||
lens[storage[2]] = storage[4] * storage[1].itemsize
|
||||
if storage[2] not in offsets: return None
|
||||
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
|
||||
ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
|
||||
ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])
|
||||
|
||||
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
|
||||
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
|
||||
|
||||
@@ -53,7 +53,7 @@ class DiskRunner(JITRunner):
|
||||
# TODO: there shouldn't actually be casts here, bitcasts should fold into the load
|
||||
if ast.src[0].op == UnaryOps.CAST:
|
||||
top_src = ast.src[0].src[0]
|
||||
# TODO: assert that this is bitcast
|
||||
assert ast.src[0].arg[1], "disk only supports bitcasts, not normal casts"
|
||||
self.new_dtype = ast.src[0].arg[0]
|
||||
else:
|
||||
top_src = ast.src[0]
|
||||
|
||||
@@ -1033,7 +1033,7 @@ class Tensor:
|
||||
return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
|
||||
def cast(self, dtype:DType) -> Tensor: return self if self.dtype == dtype else mlops.Cast.apply(self, dtype=dtype)
|
||||
def bitcast(self, dtype:DType) -> Tensor:
|
||||
assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes"
|
||||
if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
|
||||
return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
|
||||
def float(self) -> Tensor: return self.cast(dtypes.float32)
|
||||
def half(self) -> Tensor: return self.cast(dtypes.float16)
|
||||
|
||||
Reference in New Issue
Block a user