diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index e0e2102852..f07004470b 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -2,7 +2,7 @@ import json, pathlib, zipfile, pickle, tarfile, struct, functools, io from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable, TypeVar from tinygrad.tensor import Tensor from tinygrad.dtype import dtypes -from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm +from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up from tinygrad.shape.view import strides_for_shape from tinygrad.multi import MultiLazyBuffer @@ -301,15 +301,15 @@ def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]: kv_data, state_dict = gguf_load(gguf_tensor) ``` """ - reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1000_000), {}, {} + reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {} def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0] def read_str(): return str(reader.read(read_uint64()), "utf-8") def read_arr(): reader, n = readers[read_int32()], read_uint64() return [ reader() for _ in range(n) ] - readers: Dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t, f, nb in [ (0,"c",1), - (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } } + readers: Dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t,f,nb in \ + [ (0,"c",1), (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } } read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11] magic, version, n_tensors, n_kv = reader.read(4), read_int32(), read_int64(), read_int64() @@ -320,7 +320,7 @@ def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]: t_infos = [ (read_str(), tuple(read_uint64() for _ in range(read_uint32())), read_int32(), read_uint64()) for _ in range(n_tensors) ] alignment, pos = kv_data.get("general.alignment", 32), reader.tell() - data_start = pos + (alignment - pos % alignment if pos % alignment != 0 else 0) + data_start = round_up(pos, alignment) for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))