tiny gguf_load cleanup [pr] (#8174)

round_up helper
This commit is contained in:
chenyu
2024-12-11 21:32:52 -05:00
committed by GitHub
parent 151ac5f5a2
commit 7047ffd27d

View File

@@ -2,7 +2,7 @@ import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable, TypeVar
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up
from tinygrad.shape.view import strides_for_shape
from tinygrad.multi import MultiLazyBuffer
@@ -301,15 +301,15 @@ def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
kv_data, state_dict = gguf_load(gguf_tensor)
```
"""
reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1000_000), {}, {}
reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {}
# Read exactly `n` bytes from `reader` and decode them with the struct format `fmt`;
# struct.unpack returns a tuple, so only its first element is returned.
def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0]
# Read a length-prefixed string: a uint64 byte count, then that many bytes decoded as UTF-8.
def read_str(): return str(reader.read(read_uint64()), "utf-8")
# Read a typed array: an int32 type id (looked up in `readers`), then a uint64 element count,
# then that many elements, each decoded by the selected reader. Returns them as a list.
# NOTE: the name `reader` bound below is the per-element reader callable, not the buffered stream.
def read_arr():
reader, n = readers[read_int32()], read_uint64()
return [ reader() for _ in range(n) ]
readers: Dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t, f, nb in [ (0,"c",1),
(1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
readers: Dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t,f,nb in \
[ (0,"c",1), (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11]
magic, version, n_tensors, n_kv = reader.read(4), read_int32(), read_int64(), read_int64()
@@ -320,7 +320,7 @@ def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
t_infos = [ (read_str(), tuple(read_uint64() for _ in range(read_uint32())), read_int32(), read_uint64()) for _ in range(n_tensors) ]
alignment, pos = kv_data.get("general.alignment", 32), reader.tell()
data_start = pos + (alignment - pos % alignment if pos % alignment != 0 else 0)
data_start = round_up(pos, alignment)
for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))