safetensors! (#903)

* safetensors test

* safe_save

* load back with real safetensors

* bugfix in device name. add simple torch_load

* it works for llama, but it's slower...

* mmap

* no intermediate

* load mmaped

* readinto speed

* not ready yet

* revert that
George Hotz
2023-06-02 13:41:09 -07:00
committed by GitHub
parent 513aeb2f66
commit d58586bb17
7 changed files with 146 additions and 14 deletions

examples/llama.py

@@ -263,6 +263,18 @@ if __name__ == "__main__":
   del weights
 
+  # disktensor loader isn't fast yet
+  """
+  from tinygrad.state import torch_load, get_state_dict
+  state_dict = torch_load(WEIGHTS_7B_FILENAME)
+  model = Transformer(**args_7B)
+  with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
+    for k,v in (t := tqdm(get_state_dict(model).items())):
+      t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, loading {k}")
+      if k not in state_dict: continue
+      v.assign(state_dict[k].to(v.device)).realize()
+  """
+
   # *** prompt engineers work here ****
 
   if args.personality.lower() == "stacy":
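A note on the Timing helper used above and in the tests below: it comes from extra.helpers and, judging from the call sites, is a context manager taking a prefix string and an optional callback over the elapsed nanoseconds (bytes divided by nanoseconds is numerically GB/s, which is why the lambdas divide sizes by et_ns directly). A minimal sketch of such a helper, as an assumption about its shape rather than the actual extra.helpers code:

import time

class Timing:
  # hypothetical stand-in: on exit, prints "<prefix><elapsed> ms<suffix>",
  # where the suffix is computed by the callback from the elapsed nanoseconds
  def __init__(self, prefix="", on_exit=None): self.prefix, self.on_exit = prefix, on_exit
  def __enter__(self): self.st = time.perf_counter_ns()
  def __exit__(self, *exc):
    et_ns = time.perf_counter_ns() - self.st
    print(f"{self.prefix}{et_ns*1e-6:.2f} ms{self.on_exit(et_ns) if self.on_exit else ''}")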

setup.py

@@ -40,6 +40,7 @@ setup(name='tinygrad',
           "onnx2torch",
           "opencv-python",
           "tabulate",
+          "safetensors",
         ],
       },
       include_package_data=True)

test/unit/test_disk_tensor.py

@@ -2,6 +2,68 @@ import pathlib
 import unittest
 
 import numpy as np
 from tinygrad.tensor import Tensor
+from tinygrad.state import safe_load, safe_save, get_state_dict
+from tinygrad.helpers import dtypes
+from tinygrad.runtime.ops_disk import RawDiskBuffer
+from extra.helpers import Timing
+
+test_fn = pathlib.Path(__file__).parent.parent.parent / "weights/LLaMA/7B/consolidated.00.pth"
+#test_size = test_fn.stat().st_size
+test_size = 1024*1024*1024*2
+
+# sudo su -c 'sync; echo 1 > /proc/sys/vm/drop_caches' && python3 test/unit/test_disk_tensor.py TestRawDiskBuffer.test_readinto_read_speed
+@unittest.skipIf(not test_fn.exists(), "download LLaMA weights for read in speed tests")
+class TestRawDiskBuffer(unittest.TestCase):
+  def test_readinto_read_speed(self):
+    tst = np.empty(test_size, np.uint8)
+    with open(test_fn, "rb") as f:
+      with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
+        f.readinto(tst)
+  def test_mmap_read_speed(self):
+    db = RawDiskBuffer(test_size, dtype=dtypes.uint8, device=test_fn)
+    tst = np.empty(test_size, np.uint8)
+    with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
+      np.copyto(tst, db.toCPU())
+
+class TestSafetensors(unittest.TestCase):
+  def test_real_safetensors(self):
+    import torch
+    from safetensors.torch import save_file
+    torch.manual_seed(1337)
+    tensors = {
+      "weight1": torch.randn((16, 16)),
+      "weight2": torch.arange(0, 17, dtype=torch.uint8),
+      "weight3": torch.arange(0, 17, dtype=torch.int32).reshape(17,1,1),
+      "weight4": torch.arange(0, 2, dtype=torch.uint8),
+    }
+    save_file(tensors, "/tmp/model.safetensors")
+
+    ret = safe_load("/tmp/model.safetensors")
+    for k,v in tensors.items(): np.testing.assert_array_equal(ret[k].numpy(), v.numpy())
+
+    safe_save(ret, "/tmp/model.safetensors_alt")
+    with open("/tmp/model.safetensors", "rb") as f:
+      with open("/tmp/model.safetensors_alt", "rb") as g:
+        assert f.read() == g.read()
+
+    ret2 = safe_load("/tmp/model.safetensors_alt")
+    for k,v in tensors.items(): np.testing.assert_array_equal(ret2[k].numpy(), v.numpy())
+
+  def test_efficientnet_safetensors(self):
+    from models.efficientnet import EfficientNet
+    model = EfficientNet(0)
+    state_dict = get_state_dict(model)
+    safe_save(state_dict, "/tmp/eff0")
+    state_dict_loaded = safe_load("/tmp/eff0")
+    assert sorted(list(state_dict_loaded.keys())) == sorted(list(state_dict.keys()))
+    for k,v in state_dict.items():
+      np.testing.assert_array_equal(v.numpy(), state_dict_loaded[k].numpy())
+
+    # load with the real safetensors
+    from safetensors import safe_open
+    with safe_open("/tmp/eff0", framework="pt", device="cpu") as f:
+      assert sorted(list(f.keys())) == sorted(list(state_dict.keys()))
+      for k in f.keys():
+        np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
+
 class TestDiskTensor(unittest.TestCase):
   def test_empty(self):

@@ -46,6 +108,7 @@ class TestDiskTensor(unittest.TestCase):
   def test_assign_slice(self):
     pathlib.Path("/tmp/dt4").unlink(missing_ok=True)
     cc = Tensor.arange(10, device="CPU").to("disk:/tmp/dt4").realize()
+    #cc.assign(np.ones(10)).realize()
     print(cc[3:5].numpy())
     cc[3:5].assign([13, 12]).realize()
tinygrad/ops.py

@@ -297,6 +297,7 @@ class _Device:
   def __init__(self) -> None:
     self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
     self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, self._default_device())
+  def canonicalize(self, device:str) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "")
   def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: return self._get_device(x.split(":")[0].upper())
   @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
   def _get_device(self, x:str) -> Union[Interpreted, Compiled]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
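Device.canonicalize centralizes the device-string normalization that was previously inlined in Tensor.__init__ (see the tensor.py hunk below): uppercase the backend name, keep any ":suffix", and strip the redundant ":0". Some illustrative expectations, derived by inspection of the one-liner above:

assert Device.canonicalize("gpu") == "GPU"
assert Device.canonicalize("gpu:0") == "GPU"                        # ":0" is the default device, stripped
assert Device.canonicalize("gpu:1") == "GPU:1"
assert Device.canonicalize("disk:/tmp/model") == "DISK:/tmp/model"  # path suffix kept verbatim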

tinygrad/runtime/ops_disk.py

@@ -1,25 +1,31 @@
 import os, mmap
 from typing import Optional
 from typing import Callable, Dict
-from tinygrad.helpers import prod
+from tinygrad.helpers import prod, DType
 from tinygrad.runtime.lib import RawBufferMapped
 from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps
 
 class RawDiskBuffer(RawBufferMapped):
-  def __init__(self, size, dtype, device:Optional[str]=None, buf=None, shape=None):
+  def __init__(self, size, dtype:DType, device:Optional[str]=None, buf=None, shape=None, offset=0): # pylint: disable=super-init-not-called
     self.shape = (size, ) if shape is None else shape
+    self.offset = offset # this is an offset in bytes
     assert device is not None or buf is not None, "disk tensor needs a path or a buf"
     if device is not None:
       with open(device, "a+b") as f:
         if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
-        buf = memoryview(mmap.mmap(f.fileno(), size * dtype.itemsize))
-    super().__init__(size, dtype, buf)
-  def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buffer(), shape=arg)
+        buf = mmap.mmap(f.fileno(), size * dtype.itemsize)
+        buf.madvise(mmap.MADV_SEQUENTIAL)
+    # NOTE: we don't call super since disk tensors don't use RAM
+    self.size, self.dtype, self._buf = size, dtype, buf
+  def cast(self, new_dtype:DType): return RawDiskBuffer(self.size, new_dtype, buf=self._buf, shape=self.shape, offset=self.offset)
+  def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)
   def shrink(self, arg):
     assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
-    return RawDiskBuffer(arg[0][1]-arg[0][0], self.dtype, buf=self._buffer()[arg[0][0]*prod(self.shape[1:])*self.dtype.itemsize:arg[0][1]*prod(self.shape[1:])*self.dtype.itemsize])
-  def _buffer(self): return self._buf
+    offset = arg[0][0]*prod(self.shape[1:])*self.dtype.itemsize
+    size = (arg[0][1]-arg[0][0]) * prod(self.shape[1:])
+    return RawDiskBuffer(size, self.dtype, buf=self._buf, offset=self.offset+offset, shape=(arg[0][1]-arg[0][0],)+self.shape[1:])
+  def _buffer(self): return memoryview(self._buf)[self.offset:self.offset+self.size*self.dtype.itemsize]
 
-disk_fxn_for_op: Dict[Op, Callable] = { UnaryOps.NOOP: lambda x: x, MovementOps.RESHAPE: RawDiskBuffer.reshape, MovementOps.SHRINK: RawDiskBuffer.shrink }
+disk_fxn_for_op: Dict[Op, Callable] = { UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.RESHAPE: RawDiskBuffer.reshape, MovementOps.SHRINK: RawDiskBuffer.shrink }
 DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op, to_underlying=lambda x:x, from_underlying=lambda x:x)
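The rewritten RawDiskBuffer is the core of the speedup: instead of passing sliced memoryviews around, all views share one mmap object and carry an explicit byte offset, so cast/reshape/shrink stay zero-copy and _buffer() only materializes a memoryview window on demand, while madvise(MADV_SEQUENTIAL) hints the kernel to read ahead. A standalone sketch of the same idea, assuming Linux/macOS and Python 3.8+ (where mmap.madvise is available):

import mmap, os

def open_mapped(path):
  # map the whole file read-only and hint sequential access for readahead
  with open(path, "rb") as f:
    m = mmap.mmap(f.fileno(), os.path.getsize(path), prot=mmap.PROT_READ)
  m.madvise(mmap.MADV_SEQUENTIAL)
  return m

def window(m, offset, length):
  # a memoryview slice is zero-copy: bytes are only paged in when touched
  return memoryview(m)[offset:offset+length]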

tinygrad/state.py (new file, 50 lines)

@@ -0,0 +1,50 @@
+import os, json, pathlib
+from typing import Dict, Union
+from tinygrad.tensor import Tensor
+from tinygrad.helpers import dtypes, prod
+
+safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
+inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
+
+def torch_load(fn:str):
+  import zipfile, pickle
+  myzip = zipfile.ZipFile(fn, 'r')
+  base_name = myzip.namelist()[0].split('/', 1)[0]
+  t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
+  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
+    #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
+    with myzip.open(f'{base_name}/data/{storage[2]}') as myfile:
+      offset = myfile._orig_compress_start # type: ignore
+    return t[offset:offset+prod(size)].cast(storage[1]).reshape(size)
+  intercept = {"HalfStorage": dtypes.float16, "_rebuild_tensor_v2": _rebuild_tensor_v2}
+  class TorchPickle(pickle.Unpickler):
+    def find_class(self, module, name):
+      if module.startswith("torch"): return intercept[name]
+      return super().find_class(module, name)
+    def persistent_load(self, pid): return pid
+  with myzip.open(f'{base_name}/data.pkl') as myfile: return TorchPickle(myfile).load()
+
+def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
+  t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
+  json_len = t[0:1].cast(dtypes.int64).numpy()[0]
+  metadata = json.loads(t[8:8+json_len].numpy().tobytes())
+  return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"}
+
+def safe_save(tensors:Dict[str, Tensor], fn:str):
+  metadata, offset = {}, 0
+  for k,v in tensors.items():
+    metadata[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
+    offset += v.nbytes()
+  j = json.dumps(metadata, separators=(',', ':'))
+  j += "\x20"*((8-len(j)%8)%8)
+  pathlib.Path(fn).unlink(missing_ok=True)
+  t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
+  t[0:1].cast(dtypes.int64).assign([len(j)])
+  t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8))
+  for k,v in safe_load(t).items(): v.assign(tensors[k])
+
+# TODO: move get_state_dict and get_parameters here
+from tinygrad.nn.optim import get_state_dict, get_parameters # pylint: disable=unused-import # noqa: F401
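safe_save/safe_load above implement the safetensors container directly rather than calling the library: the file is an 8-byte little-endian unsigned header length, a JSON header mapping each tensor name to its dtype, shape, and byte range, then the raw tensor data (the "\x20" padding keeps the data section 8-byte aligned, which is why the round-trip test above can compare byte-identical files). A minimal dependency-free reader following the same layout, sketched with numpy instead of tinygrad tensors:

import json, struct
import numpy as np

# numpy equivalents of the safe_dtypes table above
NP_DTYPES = {"F16": np.float16, "F32": np.float32, "U8": np.uint8, "I8": np.int8, "I32": np.int32, "I64": np.int64}

def read_safetensors(fn):
  with open(fn, "rb") as f:
    header_len = struct.unpack("<Q", f.read(8))[0]  # u64 little-endian header size
    header = json.loads(f.read(header_len))
    data = f.read()  # the rest of the file is the tensor data section
  # data_offsets are relative to the start of the data section
  return {name: np.frombuffer(data[meta["data_offsets"][0]:meta["data_offsets"][1]],
                              dtype=NP_DTYPES[meta["dtype"]]).reshape(meta["shape"])
          for name, meta in header.items() if name != "__metadata__"}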

tinygrad/tensor.py

@@ -34,7 +34,7 @@ class Tensor:
   default_type: ClassVar[DType] = dtypes.float32
   def __init__(self, data:Union[int, float, list, LazyBuffer, LazyNumpyArray, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
-    device = (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") # canonicalize device
+    device = Device.canonicalize(device)
     if isinstance(data, (int, float, list)):
       data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
     elif isinstance(data, LazyBuffer) and data.device != device:

@@ -109,13 +109,11 @@ class Tensor:
   def to_(self, device:str):
     assert self.lazydata.realized is None
     self.lazydata.device = device
-    if self.grad:
-      self.grad.lazydata.device = device
+    if self.grad: self.grad.to_(device)
 
   def to(self, device:str):
     ret = Tensor(self.lazydata, device)
-    if self.grad:
-      ret.grad = self.grad.to(device)
+    if self.grad: ret.grad = self.grad.to(device)
     return ret
 
   # ***** creation helper functions *****

@@ -142,7 +140,7 @@ class Tensor:
   @staticmethod
   def empty(*shape, device=Device.DEFAULT, dtype:Optional[DType]=None, **kwargs):
     # NOTE: we do the reshape to fix interpreted buffers
-    return Tensor(LazyBuffer.empty([prod(shape)], Tensor.default_type if dtype is None else dtype, device), dtype=dtype, device=device, **kwargs).reshape(*shape)
+    return Tensor(LazyBuffer.empty([prod(shape)], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device)), dtype=dtype, device=device, **kwargs).reshape(*shape)
 
   @staticmethod
   def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)

@@ -584,6 +582,7 @@ class Tensor:
   def ndim(self) -> int: return len(self.shape)
   def numel(self) -> int: return math.prod(self.shape)
   def element_size(self) -> int: return self.dtype.itemsize
+  def nbytes(self) -> int: return self.numel() * self.element_size()
   def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
 
 # register functions to move between devices