mirror of https://github.com/tinygrad/tinygrad.git
safetensors! (#903)
* safetensors test
* safe_save
* load back with real safetensors
* bugfix in device name. add simple torch_load
* it works for llama, but it's slower...
* mmap
* no intermediate
* load mmaped
* readinto speed
* not ready yet
* revert that
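Background for the diffs below: a .safetensors file is an 8-byte little-endian unsigned length, followed by a JSON header mapping tensor names to dtype, shape, and byte offsets, followed by the raw tensor bytes. A minimal sketch of reading that layout in plain Python, independent of tinygrad (the path is illustrative):

import json, struct

with open("/tmp/model.safetensors", "rb") as f:
  header_len = struct.unpack("<Q", f.read(8))[0]  # u64 little-endian JSON length
  header = json.loads(f.read(header_len))         # e.g. {"weight1": {"dtype": "F32", "shape": [16, 16], "data_offsets": [0, 1024]}, ...}
  for name, meta in header.items():
    if name == "__metadata__": continue           # optional free-form metadata entry
    start, end = meta["data_offsets"]             # byte offsets relative to the end of the header
    f.seek(8 + header_len + start)
    raw = f.read(end - start)                     # the tensor bytes, row-major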
examples/llama.py
@@ -263,6 +263,18 @@ if __name__ == "__main__":

   del weights

+  # disktensor loader isn't fast yet
+  """
+  from tinygrad.state import torch_load, get_state_dict
+  state_dict = torch_load(WEIGHTS_7B_FILENAME)
+  model = Transformer(**args_7B)
+  with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
+    for k,v in (t := tqdm(get_state_dict(model).items())):
+      t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, loading {k}")
+      if k not in state_dict: continue
+      v.assign(state_dict[k].to(v.device)).realize()
+  """
+
   # *** prompt engineers work here ****

   if args.personality.lower() == "stacy":
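The loader above is wrapped in Timing from extra/helpers.py, which hands the elapsed nanoseconds to an optional lambda; dividing bytes by nanoseconds conveniently yields GB/s directly. A minimal sketch of a context manager with that interface (an assumption; the real extra.helpers.Timing may differ in details):

import time

class Timing:  # sketch; the real one lives in extra/helpers.py
  def __init__(self, prefix="", on_exit=None): self.prefix, self.on_exit = prefix, on_exit
  def __enter__(self): self.st = time.perf_counter_ns()
  def __exit__(self, *exc):
    et_ns = time.perf_counter_ns() - self.st  # elapsed nanoseconds, passed to the callback
    print(f"{self.prefix}{et_ns*1e-6:.2f} ms{self.on_exit(et_ns) if self.on_exit else ''}")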
setup.py
@@ -40,6 +40,7 @@ setup(name='tinygrad',
             "onnx2torch",
             "opencv-python",
             "tabulate",
+            "safetensors",
         ],
       },
       include_package_data=True)
test/unit/test_disk_tensor.py
@@ -2,6 +2,68 @@ import pathlib
 import unittest
 import numpy as np
 from tinygrad.tensor import Tensor
+from tinygrad.state import safe_load, safe_save, get_state_dict
+from tinygrad.helpers import dtypes
+from tinygrad.runtime.ops_disk import RawDiskBuffer
+from extra.helpers import Timing
+
+test_fn = pathlib.Path(__file__).parent.parent.parent / "weights/LLaMA/7B/consolidated.00.pth"
+#test_size = test_fn.stat().st_size
+test_size = 1024*1024*1024*2
+
+# sudo su -c 'sync; echo 1 > /proc/sys/vm/drop_caches' && python3 test/unit/test_disk_tensor.py TestRawDiskBuffer.test_readinto_read_speed
+@unittest.skipIf(not test_fn.exists(), "download LLaMA weights for read in speed tests")
+class TestRawDiskBuffer(unittest.TestCase):
+  def test_readinto_read_speed(self):
+    tst = np.empty(test_size, np.uint8)
+    with open(test_fn, "rb") as f:
+      with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
+        f.readinto(tst)
+
+  def test_mmap_read_speed(self):
+    db = RawDiskBuffer(test_size, dtype=dtypes.uint8, device=test_fn)
+    tst = np.empty(test_size, np.uint8)
+    with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
+      np.copyto(tst, db.toCPU())
+
+class TestSafetensors(unittest.TestCase):
+  def test_real_safetensors(self):
+    import torch
+    from safetensors.torch import save_file
+    torch.manual_seed(1337)
+    tensors = {
+      "weight1": torch.randn((16, 16)),
+      "weight2": torch.arange(0, 17, dtype=torch.uint8),
+      "weight3": torch.arange(0, 17, dtype=torch.int32).reshape(17,1,1),
+      "weight4": torch.arange(0, 2, dtype=torch.uint8),
+    }
+    save_file(tensors, "/tmp/model.safetensors")
+
+    ret = safe_load("/tmp/model.safetensors")
+    for k,v in tensors.items(): np.testing.assert_array_equal(ret[k].numpy(), v.numpy())
+    safe_save(ret, "/tmp/model.safetensors_alt")
+    with open("/tmp/model.safetensors", "rb") as f:
+      with open("/tmp/model.safetensors_alt", "rb") as g:
+        assert f.read() == g.read()
+    ret2 = safe_load("/tmp/model.safetensors_alt")
+    for k,v in tensors.items(): np.testing.assert_array_equal(ret2[k].numpy(), v.numpy())
+
+  def test_efficientnet_safetensors(self):
+    from models.efficientnet import EfficientNet
+    model = EfficientNet(0)
+    state_dict = get_state_dict(model)
+    safe_save(state_dict, "/tmp/eff0")
+    state_dict_loaded = safe_load("/tmp/eff0")
+    assert sorted(list(state_dict_loaded.keys())) == sorted(list(state_dict.keys()))
+    for k,v in state_dict.items():
+      np.testing.assert_array_equal(v.numpy(), state_dict_loaded[k].numpy())
+
+    # load with the real safetensors
+    from safetensors import safe_open
+    with safe_open("/tmp/eff0", framework="pt", device="cpu") as f:
+      assert sorted(list(f.keys())) == sorted(list(state_dict.keys()))
+      for k in f.keys():
+        np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
+
 class TestDiskTensor(unittest.TestCase):
   def test_empty(self):
@@ -46,6 +108,7 @@ class TestDiskTensor(unittest.TestCase):
   def test_assign_slice(self):
     pathlib.Path("/tmp/dt4").unlink(missing_ok=True)
     cc = Tensor.arange(10, device="CPU").to("disk:/tmp/dt4").realize()

     #cc.assign(np.ones(10)).realize()
     print(cc[3:5].numpy())
+    cc[3:5].assign([13, 12]).realize()
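The two speed tests compare a plain readinto against going through the mmap'd disk buffer; the drop_caches line in the comment evicts the page cache first, so the benchmark measures the disk rather than RAM. readinto matters (per the "no intermediate" commit bullet) because it fills a caller-supplied buffer instead of allocating a fresh bytes object for every read. A small sketch of the difference (file path illustrative):

import numpy as np

buf = np.empty(1024 * 1024, np.uint8)
with open("/tmp/model.safetensors", "rb") as f:
  data = f.read(buf.nbytes)  # allocates a new bytes object you then copy out: two passes over the data
  f.seek(0)
  n = f.readinto(buf)        # fills the preallocated array directly: one pass, no intermediate copy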
@@ -297,6 +297,7 @@ class _Device:
   def __init__(self) -> None:
     self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
     self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, self._default_device())
+  def canonicalize(self, device:str) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "")
   def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: return self._get_device(x.split(":")[0].upper())
   @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
   def _get_device(self, x:str) -> Union[Interpreted, Compiled]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
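The new Device.canonicalize uppercases only the backend part of a device string, keeps any suffix after the colon, and strips a trailing ":0", so "gpu:0" and "GPU" name the same device (this was the "bugfix in device name" bullet). A standalone restatement of the one-liner with a few spot checks:

def canonicalize(device: str) -> str:
  backend, _, rest = device.partition(":")
  return (backend.upper() + (":" + rest if rest else "")).replace(":0", "")

assert canonicalize("gpu") == "GPU"
assert canonicalize("gpu:0") == "GPU"                    # ":0" is the default instance, dropped
assert canonicalize("cuda:1") == "CUDA:1"
assert canonicalize("disk:/tmp/dt4") == "DISK:/tmp/dt4"  # only the backend is uppercased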
tinygrad/runtime/ops_disk.py
@@ -1,25 +1,31 @@
 import os, mmap
 from typing import Optional
 from typing import Callable, Dict
-from tinygrad.helpers import prod
+from tinygrad.helpers import prod, DType
 from tinygrad.runtime.lib import RawBufferMapped
 from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps

 class RawDiskBuffer(RawBufferMapped):
-  def __init__(self, size, dtype, device:Optional[str]=None, buf=None, shape=None):
+  def __init__(self, size, dtype:DType, device:Optional[str]=None, buf=None, shape=None, offset=0): # pylint: disable=super-init-not-called
     self.shape = (size, ) if shape is None else shape
+    self.offset = offset # this is an offset in bytes
     assert device is not None or buf is not None, "disk tensor needs a path or a buf"
     if device is not None:
       with open(device, "a+b") as f:
         if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
-        buf = memoryview(mmap.mmap(f.fileno(), size * dtype.itemsize))
-    super().__init__(size, dtype, buf)
-  def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buffer(), shape=arg)
+        buf = mmap.mmap(f.fileno(), size * dtype.itemsize)
+        buf.madvise(mmap.MADV_SEQUENTIAL)
+    # NOTE: we don't call super since disk tensors don't use RAM
+    self.size, self.dtype, self._buf = size, dtype, buf
+  def cast(self, new_dtype:DType): return RawDiskBuffer(self.size, new_dtype, buf=self._buf, shape=self.shape, offset=self.offset)
+  def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)
   def shrink(self, arg):
     assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
-    return RawDiskBuffer(arg[0][1]-arg[0][0], self.dtype, buf=self._buffer()[arg[0][0]*prod(self.shape[1:])*self.dtype.itemsize:arg[0][1]*prod(self.shape[1:])*self.dtype.itemsize])
-  def _buffer(self): return self._buf
+    offset = arg[0][0]*prod(self.shape[1:])*self.dtype.itemsize
+    size = (arg[0][1]-arg[0][0]) * prod(self.shape[1:])
+    return RawDiskBuffer(size, self.dtype, buf=self._buf, offset=self.offset+offset, shape=(arg[0][1]-arg[0][0],)+self.shape[1:])
+  def _buffer(self): return memoryview(self._buf)[self.offset:self.offset+self.size*self.dtype.itemsize]

-disk_fxn_for_op: Dict[Op, Callable] = { UnaryOps.NOOP: lambda x: x, MovementOps.RESHAPE: RawDiskBuffer.reshape, MovementOps.SHRINK: RawDiskBuffer.shrink }
+disk_fxn_for_op: Dict[Op, Callable] = { UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.RESHAPE: RawDiskBuffer.reshape, MovementOps.SHRINK: RawDiskBuffer.shrink }

 DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op, to_underlying=lambda x:x, from_underlying=lambda x:x)
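After this change, shrink returns a view into the same mmap instead of slicing a memoryview, so disk tensors can be sliced without copying and byte offsets accumulate across nested slices. A worked example of the arithmetic, using the (17, 1, 1) int32 shape from the tests (numbers illustrative):

# given a RawDiskBuffer with shape (17, 1, 1), dtype int32 (itemsize 4), offset 0,
# shrink(((3, 5), (0, 1), (0, 1))) computes:
#   offset = 3 * prod((1, 1)) * 4   = 12 bytes into the file
#   size   = (5 - 3) * prod((1, 1)) = 2 elements
# so _buffer() on the result is memoryview(mmap)[12:12 + 2*4], a view into the file, no copy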
tinygrad/state.py (new file)
@@ -0,0 +1,50 @@
import os, json, pathlib
from typing import Dict, Union
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, prod

safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}

def torch_load(fn:str):
  import zipfile, pickle
  myzip = zipfile.ZipFile(fn, 'r')
  base_name = myzip.namelist()[0].split('/', 1)[0]
  t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")

  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
    #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
    with myzip.open(f'{base_name}/data/{storage[2]}') as myfile:
      offset = myfile._orig_compress_start # type: ignore
      return t[offset:offset+prod(size)].cast(storage[1]).reshape(size)

  intercept = {"HalfStorage": dtypes.float16, "_rebuild_tensor_v2": _rebuild_tensor_v2}
  class TorchPickle(pickle.Unpickler):
    def find_class(self, module, name):
      if module.startswith("torch"): return intercept[name]
      return super().find_class(module, name)
    def persistent_load(self, pid): return pid

  with myzip.open(f'{base_name}/data.pkl') as myfile: return TorchPickle(myfile).load()
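torch_load never decompresses anything: .pth checkpoints are zip archives whose tensor payloads are stored uncompressed, so each tensor's bytes sit at a fixed offset in the file. The whole checkpoint is opened as one big disk tensor, and _orig_compress_start (a private zipfile attribute pointing at the start of an entry's data) turns each weight into a slice of it. The same offset can be derived from zip's public structure, as a sanity check (a sketch; the 30-byte local file header plus name and extra fields is standard zip layout):

import struct, zipfile

def stored_entry_offset(fn: str, name: str) -> int:
  zi = zipfile.ZipFile(fn).getinfo(name)
  assert zi.compress_type == zipfile.ZIP_STORED  # bytes must be uncompressed to slice them in place
  with open(fn, "rb") as f:
    f.seek(zi.header_offset)
    hdr = f.read(30)                             # fixed-size local file header
    name_len, extra_len = struct.unpack("<HH", hdr[26:30])
  return zi.header_offset + 30 + name_len + extra_len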
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
  t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
  json_len = t[0:1].cast(dtypes.int64).numpy()[0]
  metadata = json.loads(t[8:8+json_len].numpy().tobytes())
  return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"}

def safe_save(tensors:Dict[str, Tensor], fn:str):
  metadata, offset = {}, 0
  for k,v in tensors.items():
    metadata[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
    offset += v.nbytes()
  j = json.dumps(metadata, separators=(',', ':'))
  j += "\x20"*((8-len(j)%8)%8)
  pathlib.Path(fn).unlink(missing_ok=True)
  t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
  t[0:1].cast(dtypes.int64).assign([len(j)])
  t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8))
  for k,v in safe_load(t).items(): v.assign(tensors[k])

# TODO: move get_state_dict and get_parameters here
from tinygrad.nn.optim import get_state_dict, get_parameters # pylint: disable=unused-import # noqa: F401
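Note the trick in safe_save: it writes only the length word (by casting a one-element uint8 slice to int64, which reinterprets the same 8 bytes) and the JSON header, then calls safe_load on the freshly created disk tensor and assigns each source tensor into the returned views, so the data bytes land exactly at the offsets the header promises. A round-trip sketch at the API level (names and path illustrative):

from tinygrad.tensor import Tensor
from tinygrad.state import safe_save, safe_load

state = {"weight": Tensor.ones(16, 16)}    # float32 maps to "F32", 16*16*4 = 1024 bytes
safe_save(state, "/tmp/demo.safetensors")
back = safe_load("/tmp/demo.safetensors")  # dict of disk-backed Tensor views
print(back["weight"].numpy().shape)        # (16, 16)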
tinygrad/tensor.py
@@ -34,7 +34,7 @@ class Tensor:
   default_type: ClassVar[DType] = dtypes.float32

   def __init__(self, data:Union[int, float, list, LazyBuffer, LazyNumpyArray, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
-    device = (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") # canonicalize device
+    device = Device.canonicalize(device)
     if isinstance(data, (int, float, list)):
       data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
     elif isinstance(data, LazyBuffer) and data.device != device:
@@ -109,13 +109,11 @@ class Tensor:
   def to_(self, device:str):
     assert self.lazydata.realized is None
     self.lazydata.device = device
-    if self.grad:
-      self.grad.lazydata.device = device
+    if self.grad: self.grad.to_(device)

   def to(self, device:str):
     ret = Tensor(self.lazydata, device)
-    if self.grad:
-      ret.grad = self.grad.to(device)
+    if self.grad: ret.grad = self.grad.to(device)
     return ret

   # ***** creation helper functions *****
@@ -142,7 +140,7 @@ class Tensor:
   @staticmethod
   def empty(*shape, device=Device.DEFAULT, dtype:Optional[DType]=None, **kwargs):
     # NOTE: we do the reshape to fix interpreted buffers
-    return Tensor(LazyBuffer.empty([prod(shape)], Tensor.default_type if dtype is None else dtype, device), dtype=dtype, device=device, **kwargs).reshape(*shape)
+    return Tensor(LazyBuffer.empty([prod(shape)], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device)), dtype=dtype, device=device, **kwargs).reshape(*shape)

   @staticmethod
   def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)
@@ -584,6 +582,7 @@ class Tensor:
   def ndim(self) -> int: return len(self.shape)
   def numel(self) -> int: return math.prod(self.shape)
   def element_size(self) -> int: return self.dtype.itemsize
+  def nbytes(self) -> int: return self.numel() * self.element_size()
   def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)

   # register functions to move between devices
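The new nbytes is what safe_save uses to build data_offsets, simply numel times element_size:

# e.g. for a (16, 16) float32 tensor:
#   numel() = 256, element_size() = 4, nbytes() = 1024
# which is exactly the span safe_save records as "data_offsets": [0, 1024]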