Shape changing bitcast and assert bitcast in disk (#3973)

* Shape changing bitcast

* only support it on disk

* basic test

* more tests

* RuntimeError instead of assert

* create unique temp files

* move tests that use disk to test_disk_tensor

* linter

* remove assert on error messages

* that's RuntimeError now

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
uuuvn
2024-03-29 06:49:10 +02:00
committed by GitHub
parent 793ab0512e
commit 8a40d7d423
6 changed files with 56 additions and 11 deletions

View File

@@ -249,7 +249,7 @@ class TestUint8Dtype(TestDType):
@unittest.skipIf(Device.DEFAULT == "WEBGL", "No bitcast on WebGL")
class TestBitCast(unittest.TestCase):
def test_shape_change_bitcast(self):
  # Shape-changing bitcast is only supported on DISK devices, so on the default
  # device this must raise RuntimeError (not AssertionError — that was the old behavior).
  with self.assertRaises(RuntimeError):
    _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])
def test_bitcast_float_to_int32(self):

View File

@@ -1,6 +1,7 @@
import pathlib, unittest
import pathlib, tempfile, unittest
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DType
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import Timing, fetch, temp
from test.helpers import is_dtype_supported
@@ -36,14 +37,50 @@ test_fn = pathlib.Path(__file__).parents[2] / "weights/LLaMA/7B/consolidated.00.
#test_size = test_fn.stat().st_size
# fixed 2 GiB read size for the disk read-speed benchmark (avoids stat-ing the weights file at import time)
test_size = 1024*1024*1024*2
def _test_bitcasted(t: Tensor, dt: DType, expected):
  """Bitcast `t` to dtype `dt` and assert every resulting element is close to `expected`."""
  got = t.bitcast(dt).numpy()
  np.testing.assert_allclose(got, expected)
# sudo su -c 'sync; echo 1 > /proc/sys/vm/drop_caches' && python3 test/unit/test_disk_tensor.py TestRawDiskBuffer.test_readinto_read_speed
class TestRawDiskBuffer(unittest.TestCase):
  """Tests for the raw DISK-backed buffer: read throughput and bitcast semantics."""
  # NOTE: the skipIf belongs on test_readinto_read_speed only — the bitcast test below
  # creates its own temp file and must not be gated on the LLaMA weights being present.
  @unittest.skipIf(not test_fn.exists(), "download LLaMA weights for read in speed tests")
  def test_readinto_read_speed(self):
    # Measure raw readinto throughput from the on-disk weights file.
    tst = np.empty(test_size, np.uint8)
    with open(test_fn, "rb") as f:
      with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
        f.readinto(tst)
  def test_bitcasts_on_disk(self):
    # NOTE(review): tempfile.mktemp is deprecated/racy; mkstemp pre-creates the file, which
    # may interact differently with the disk device opening the path — confirm before switching.
    tmp = tempfile.mktemp()
    try:
      # ground truth = https://evanw.github.io/float-toy/
      t = Tensor.empty((128, 128), dtype=dtypes.uint8, device=f"disk:{tmp}") # uint8
      # all zeroes
      _test_bitcasted(t, dtypes.float16, 0.0)
      _test_bitcasted(t, dtypes.uint16, 0)
      _test_bitcasted(t, dtypes.float32, 0.0)
      _test_bitcasted(t, dtypes.uint32, 0)
      # pi in float16 stored via int16
      t.bitcast(dtypes.uint16).assign(Tensor.full((128, 64), 0x4248, dtype=dtypes.uint16)).realize()
      _test_bitcasted(t, dtypes.float16, 3.141)
      _test_bitcasted(t, dtypes.float32, 50.064727)
      _test_bitcasted(t, dtypes.uint16, 0x4248)
      _test_bitcasted(t, dtypes.uint32, 0x42484248)
      # pi in float32 stored via float32
      t.bitcast(dtypes.float32).assign(Tensor.full((128, 32), 3.1415927, dtype=dtypes.float32)).realize()
      _test_bitcasted(t, dtypes.float32, 3.1415927)
      _test_bitcasted(t, dtypes.uint32, 0x40490FDB)
      # doesn't support normal (value-converting) cast on disk
      with self.assertRaises(RuntimeError):
        Tensor.empty((4,), dtype=dtypes.int16, device=f"disk:{tmp}").cast(dtypes.float16)
      # Those two should be moved to test_dtype.py:test_shape_change_bitcast after bitcast works on non-disk
      with self.assertRaises(RuntimeError):
        # should fail because 3 int8 is 3 bytes but float16 is two bytes and 3 isn't a multiple of 2
        Tensor.empty((3,), dtype=dtypes.int8, device=f"DISK:{tmp}").bitcast(dtypes.float16)
      with self.assertRaises(RuntimeError):
        # should fail because backprop through bitcast is undefined
        Tensor.empty((4,), dtype=dtypes.int8, requires_grad=True, device=f"DISK:{tmp}").bitcast(dtypes.float16)
    finally:
      # remove the temp file even when an assertion above fails
      pathlib.Path(tmp).unlink(missing_ok=True)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu doesn't support uint8 datatype")
class TestSafetensors(unittest.TestCase):

View File

@@ -79,10 +79,18 @@ class LazyBuffer:
def cast(self, dtype:DType, bitcast:bool=False):
  """Return a lazybuffer casting self to dtype. bitcast=True reinterprets bytes instead of
  converting values; a bitcast between dtypes of different itemsize changes the last axis
  and is only allowed on DISK. Raises RuntimeError on unsupported casts."""
  if self.dtype == dtype: return self
  # DISK buffers can only be reinterpreted, never value-converted
  if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
  # TODO: applying this makes gpt2 slower
  if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
    return self.base.cast(dtype, bitcast)._view(self.st)
  new_shape = self.shape
  if bitcast and self.dtype.itemsize != dtype.itemsize:
    if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
    if not all_int(new_shape): raise RuntimeError("shape changing bitcast with symbolic shape isn't supported yet")
    # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
    # last-axis byte count must divide evenly into the new itemsize
    if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
    new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
  return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,))
# True iff the underlying base buffer is an unrealized CONST load (self may be a view of it)
def is_unrealized_const(self): return not self.base.realized and self.base.op is LoadOps.CONST
# True iff this buffer itself (not a view: base == self) is an unrealized CONST load
def is_unrealized_contiguous_const(self): return self.base == self and not self.base.realized and self.op is LoadOps.CONST

View File

@@ -14,7 +14,7 @@ inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
  """Open a safetensors file (path or already-open disk Tensor) and return
  (disk tensor, JSON header length, parsed JSON header)."""
  t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
  # first 8 bytes are the 64-bit length of the JSON header; bitcast (not cast) since
  # value-converting casts are forbidden on DISK buffers
  json_len = t[0:8].bitcast(dtypes.int64).item()
  return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
@@ -23,8 +23,8 @@ def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
for k,v in metadata.items():
if k == "__metadata__": continue
dtype = safe_dtypes[v['dtype']]
sz = (v['data_offsets'][1]-v['data_offsets'][0])//dtype.itemsize
ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].cast(dtype).reshape(v['shape'])
sz = (v['data_offsets'][1]-v['data_offsets'][0])
ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
return ret
def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
@@ -37,7 +37,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any
j += "\x20"*((8-len(j)%8)%8)
pathlib.Path(fn).unlink(missing_ok=True)
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
t[0:1].cast(dtypes.int64).assign([len(j)])
t[0:8].bitcast(dtypes.int64).assign([len(j)])
t[8:8+len(j)].assign(list(j.encode('utf-8')))
for k,v in safe_load(t).items(): v.assign(tensors[k])
@@ -83,7 +83,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
lens[storage[2]] = storage[4] * storage[1].itemsize
if storage[2] not in offsets: return None
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]

View File

@@ -53,7 +53,7 @@ class DiskRunner(JITRunner):
# TODO: there shouldn't actually be casts here, bitcasts should fold into the load
if ast.src[0].op == UnaryOps.CAST:
top_src = ast.src[0].src[0]
# TODO: assert that this is bitcast
assert ast.src[0].arg[1], "disk only supports bitcasts, not normal casts"
self.new_dtype = ast.src[0].arg[0]
else:
top_src = ast.src[0]

View File

@@ -1033,7 +1033,7 @@ class Tensor:
return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
def cast(self, dtype:DType) -> Tensor:
  """Value-converting cast to `dtype`; returns self unchanged when already that dtype."""
  if self.dtype == dtype: return self
  return mlops.Cast.apply(self, dtype=dtype)
def bitcast(self, dtype:DType) -> Tensor:
  """Reinterpret the underlying bytes as `dtype` without converting values.
  Mismatched itemsizes are allowed (shape-changing bitcast; currently DISK-only,
  enforced downstream). Raises RuntimeError if the tensor requires grad, since
  backprop through a bitcast is undefined."""
  if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
  return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
# convenience alias for cast(dtypes.float32)
def float(self) -> Tensor: return self.cast(dtypes.float32)
# convenience alias for cast(dtypes.float16)
def half(self) -> Tensor: return self.cast(dtypes.float16)