GGUF support (#7046)
* basic loader, untested
* testing
* remove utils import in test
* q8_0
* q4_1
* end to end testing
* minor cleanup
* fix casting
* moved to state
* move tests
* move dequant to fn
* fix lint elif
* remove gguf from extra
* fix dict union
* q6_k simpler
* naming and spacing
* gpt2-gguf example
* cleanup
* move gguf example
* minor cleanup

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
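In short, this PR adds a GGUF loader (`load_gguf` in `tinygrad.nn.state`) that performs Q4_0, Q4_1, Q8_0 and Q6_K dequantization as tinygrad tensor ops, tests that cross-check the loader against the `ggml-python` bindings, and a GPT-2 example that can run from quantized GGUF weights. A minimal usage sketch, mirroring the docstring added further down (the file name is only a placeholder):

```python
import os
from tinygrad import Tensor, Device
from tinygrad.dtype import dtypes
from tinygrad.nn.state import load_gguf

fn = "stories15M-q8_0.gguf"  # placeholder: any GGUF file using a supported quantization type
# map the file into a 1-D uint8 tensor, move it to the default device, then parse it
gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
kv_data, state_dict = load_gguf(gguf_tensor)  # metadata dict and {name: Tensor} dict
```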
@@ -1,13 +1,15 @@
 #!/usr/bin/env python3
+import os
 from typing import Optional, Union
 import argparse
 import numpy as np
 import tiktoken
 from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable
+from tinygrad.dtype import dtypes
 from tinygrad.ops import UOp
 from tinygrad.helpers import Timing, DEBUG, JIT, getenv, fetch, colored, trange
 from tinygrad.nn import Embedding, Linear, LayerNorm
-from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
+from tinygrad.nn.state import load_gguf, torch_load, load_state_dict, get_state_dict

 MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
 HALF = getenv("HALF")
@@ -143,6 +145,34 @@ class GPT2:
     return GPT2(model, tokenizer)

+  @staticmethod
+  def build_gguf(model_size: str):
+    q_type = model_size[len("gpt2_gguf_"):].upper()
+    fn = fetch(f"https://huggingface.co/PrunaAI/gpt2-GGUF-smashed/resolve/main/gpt2.{q_type}.gguf?download=true")
+    gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
+    kv_data, state_dict = load_gguf(gguf_tensor)
+
+    gpt2_params = {
+      "dim": kv_data["gpt2.embedding_length"], "n_heads": kv_data["gpt2.attention.head_count"],
+      "n_layers": kv_data["gpt2.block_count"], "norm_eps": kv_data["gpt2.attention.layer_norm_epsilon"],
+      "vocab_size": 50257, "max_seq_len": kv_data["gpt2.context_length"],
+    }
+    def _remap_gguf_key(key: str):
+      replaces = [
+        ("blk.", "h."), (".attn_qkv.bias", ".attn.c_attn.bias"), (".attn_qkv.weight", ".attn.c_attn.weight"),
+        (".ffn_norm.bias", ".ln_2.bias"), (".ffn_norm.weight", ".ln_2.weight"), (".attn_norm.bias", ".ln_1.bias"),
+        (".attn_norm.weight", ".ln_1.weight"), (".attn_output.bias", ".attn.c_proj.bias"), (".attn_output.weight", ".attn.c_proj.weight"),
+        (".ffn_up.bias", ".mlp.c_fc.bias"), (".ffn_up.weight", ".mlp.c_fc.weight"), (".ffn_down.bias", ".mlp.c_proj.bias"),
+        (".ffn_down.weight", ".mlp.c_proj.weight"), ("token_embd.weight", "wte.weight"), ("output.weight", "lm_head.weight"),
+        ("output_norm.bias", "ln_f.bias"), ("output_norm.weight", "ln_f.weight"), ("position_embd.weight", "wpe.weight"),
+      ]
+      for ostr, ns in replaces: key = key.replace(ostr, ns)
+      return key
+    state_dict = { _remap_gguf_key(k): v for k, v in state_dict.items() }
+    model = Transformer(**gpt2_params)
+    load_state_dict(model, state_dict)
+    return GPT2(model, tiktoken.get_encoding("gpt2"))
+
   def __init__(self, model, tokenizer):
     self.model = model
     self.tokenizer = tokenizer
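`_remap_gguf_key` translates llama.cpp-style GGUF tensor names into the attribute names of the existing `Transformer`. A few worked examples of the replacement table above (illustrative, not part of the diff):

```python
# hypothetical inputs run through the replacement table above
assert _remap_gguf_key("blk.0.attn_qkv.weight") == "h.0.attn.c_attn.weight"
assert _remap_gguf_key("blk.11.ffn_down.bias") == "h.11.mlp.c_proj.bias"
assert _remap_gguf_key("output_norm.weight") == "ln_f.weight"
```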
@@ -191,7 +221,7 @@ if __name__ == "__main__":
   np.random.seed(args.seed)

   print(f"using {args.model_size}")
-  gpt2 = GPT2.build(args.model_size)
+  gpt2 = GPT2.build_gguf(args.model_size) if args.model_size.startswith("gpt2_gguf_") else GPT2.build(args.model_size)

   if args.benchmark != -1:
     gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
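With this dispatch, GGUF variants reuse the existing `--model_size` flag: any value starting with `gpt2_gguf_` is routed to `build_gguf`, which uppercases the suffix to pick the quantization and fetch the matching file. A small worked example (the flag value is hypothetical, but `gpt2.Q4_1.gguf` is one of the files the tests download):

```python
model_size = "gpt2_gguf_q4_1"                    # hypothetical --model_size value
q_type = model_size[len("gpt2_gguf_"):].upper()  # -> "Q4_1"
# build_gguf then fetches gpt2.Q4_1.gguf from the PrunaAI/gpt2-GGUF-smashed repo
```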
setup.py (1 addition)
@@ -57,6 +57,7 @@ setup(name='tinygrad',
         "hypothesis",
         "nibabel",
         "bottle",
+        "ggml-python"
       ],
       'docs': [
         "mkdocs",
@@ -1,11 +1,11 @@
 import os
 import pathlib, tempfile, unittest
-import tarfile
+import tarfile, ggml, ctypes

 import numpy as np
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.dtype import DType
-from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract
+from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract, ggml_data_to_tensor, load_gguf
 from tinygrad.helpers import Timing, fetch, temp, CI
 from test.helpers import is_dtype_supported
@@ -445,5 +445,95 @@ class TestPathTensor(unittest.TestCase):
     self.assertEqual(t_cpu.device, "CPU")
     np.testing.assert_array_equal(t_cpu.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))

+ggml_test_block_count = 4
+ggml_type_to_np_dtype = {
+  ggml.GGML_TYPE_F16: np.float16, ggml.GGML_TYPE_F32:np.float32, ggml.GGML_TYPE_F64:np.float64,
+  ggml.GGML_TYPE_I8:np.int8, ggml.GGML_TYPE_I16: np.int16, ggml.GGML_TYPE_I32: np.int32, ggml.GGML_TYPE_I64: np.int64,
+}
+np_dtype_to_ctype = { np.float16: ctypes.c_uint16 }
+gguf_val_getters = [
+  ggml.gguf_get_val_u8, ggml.gguf_get_val_i8, ggml.gguf_get_val_u16, ggml.gguf_get_val_i16,
+  ggml.gguf_get_val_u32, ggml.gguf_get_val_i32, ggml.gguf_get_val_f32, ggml.gguf_get_val_bool,
+  lambda *args: ggml.gguf_get_val_str(*args).decode("utf-8"), None,
+  ggml.gguf_get_val_u64, ggml.gguf_get_val_i64, ggml.gguf_get_val_f64,
+]
+
+def ggml_tensor_to_numpy(tensor: ggml.ggml_tensor_p):
+  ctx: ggml.ggml_context_p | None = None
+  ggml_type, n_dims, n_els = tensor.contents.type, ggml.ggml_n_dims(tensor), ggml.ggml_nelements(tensor)
+  shape = tuple(reversed(tensor.contents.ne[:n_dims]))
+  if ggml_type not in ggml_type_to_np_dtype:
+    ctx = ggml.ggml_init(ggml.ggml_init_params(mem_size=n_els * 5 + 500, mem_buffer=None))
+    ntensor = ggml.ggml_new_tensor(ctx, ggml.GGML_TYPE_F32, n_dims, tensor.contents.ne)
+    type_traits = ggml.ggml_internal_get_type_traits(ggml_type)
+    type_traits.to_float(ggml.ggml_get_data(tensor), ggml.ggml_get_data_f32(ntensor), n_els)
+    tensor, ggml_type = ntensor, ggml.GGML_TYPE_F32
+
+  np_type = ggml_type_to_np_dtype[ggml_type]
+  ctypes_type = np_dtype_to_ctype.get(np_type, None) or np.ctypeslib.as_ctypes_type(np_type)
+  data = ggml.ggml_get_data(tensor)
+  if data is None: raise ValueError("tensor data is None")
+  arr = (ctypes_type * ggml.ggml_nelements(tensor)).from_address(data)
+  strides = tuple(reversed(tensor.contents.nb[:n_dims]))
+  output = np.ctypeslib.as_array(arr)
+  output.dtype = np_type
+  return np.lib.stride_tricks.as_strided(output, shape=shape, strides=strides), ctx
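One convention worth calling out in `ggml_tensor_to_numpy`: GGML stores `ne` (extents) and `nb` (byte strides) innermost-dimension first, so both are reversed before building the numpy view; the same convention is why `load_gguf` below ends with `reshape(*reversed(dims))`. A tiny illustration with made-up numbers:

```python
ne = (4096, 32)                 # hypothetical GGML extents, fastest-varying dimension first
np_shape = tuple(reversed(ne))  # -> (32, 4096), the row-major shape numpy and tinygrad expect
```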
+@unittest.skipIf(any(not is_dtype_supported(t) for t in [ dtypes.uint8, dtypes.half ]), "Backend must support uint8 and half")
+class TestGGUF(unittest.TestCase):
+  def setUp(self) -> None:
+    params = ggml.ggml_init_params(mem_size=0, mem_buffer=None, no_alloc=False)
+    self.ctx = ctypes.cast(ggml.ggml_init(params), ctypes.POINTER(ctypes.c_void_p))
+  def tearDown(self) -> None: ggml.ggml_free(self.ctx)
+
+  def test_load_tinyllama_q8_0(self): self._test_load_gguf("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q8_0.gguf?download=true")
+  def test_load_tinyllama_q4_0(self): self._test_load_gguf("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf?download=true")
+  def test_load_gpt2_q4_1(self): self._test_load_gguf("https://huggingface.co/PrunaAI/gpt2-GGUF-smashed/resolve/main/gpt2.Q4_1.gguf?download=true")
+  def test_load_sample_q6_k(self): self._test_load_gguf("https://huggingface.co/Isotr0py/test-gguf-sample/resolve/main/Quant_Q6_K_1024.gguf?download=true")
+
+  def test_dequantization_q4_0(self): self._test_dequantization(ggml.GGML_TYPE_Q4_0)
+  def test_dequantization_q4_1(self): self._test_dequantization(ggml.GGML_TYPE_Q4_1)
+  def test_dequantization_q8_0(self): self._test_dequantization(ggml.GGML_TYPE_Q8_0)
+  def test_dequantization_q6_k(self): self._test_dequantization(ggml.GGML_TYPE_Q6_K)
+
+  def _test_dequantization(self, ttype: int):
+    type_traits = ggml.ggml_internal_get_type_traits(ttype)
+    n_el, n_bytes = ggml_test_block_count * type_traits.blck_size, ggml_test_block_count * type_traits.type_size
+
+    data_in = (np.random.random((n_el,)).astype(np.float32) * 100 - 50).ctypes.data_as(ctypes.POINTER(ctypes.c_float))
+
+    c_q_data, c_dq_data = (ctypes.c_char * n_bytes)(0), (ctypes.c_float * n_el)(0)
+    type_traits.from_float(data_in, c_q_data, n_el)
+    type_traits.to_float(c_q_data, c_dq_data, n_el)
+
+    q_tensor = Tensor(np.frombuffer(c_q_data, dtype=np.uint8, count=n_bytes))
+    dq_tensor = ggml_data_to_tensor(q_tensor, n_el, ttype).reshape(n_el)
+
+    np.testing.assert_equal(dq_tensor.numpy(), np.frombuffer(c_dq_data, dtype=np.float32))
+  def _test_load_gguf(self, url: str):
+    fp = fetch(url)
+    model_size = os.stat(fp).st_size
+    gguf_tensor = Tensor.empty(model_size, dtype=dtypes.uint8, device=f"disk:{fp}").to(Device.DEFAULT)
+    kv_data, tensors = load_gguf(gguf_tensor)
+
+    gguf_params = ggml.gguf_init_params(ctx=self.ctx, no_alloc=False)
+    gguf_ctx = ggml.gguf_init_from_file(str(fp).encode("utf8"), gguf_params)
+    param_ctx = gguf_params.ctx.contents.value
+
+    for ggml_tensor_idx in range(ggml.gguf_get_n_tensors(gguf_ctx)):
+      tensor_name = ggml.gguf_get_tensor_name(gguf_ctx, ggml_tensor_idx)
+      ggml_tensor = ggml.ggml_get_tensor(param_ctx, tensor_name)
+      ggml_tensor_numpy, temp_ctx = ggml_tensor_to_numpy(ggml_tensor)
+      tensor = tensors.get(tensor_name.decode("utf-8"))
+      np.testing.assert_equal(tensor.numpy(), ggml_tensor_numpy)
+      if temp_ctx is not None: ggml.ggml_free(temp_ctx)
+
+    for gguf_key_id in range(ggml.gguf_get_n_kv(gguf_ctx)):
+      v = kv_data[ggml.gguf_get_key(gguf_ctx, gguf_key_id).decode("utf-8")]
+      v_type = ggml.gguf_get_kv_type(gguf_ctx, gguf_key_id)
+      if (get_fn := gguf_val_getters[v_type]) is not None: self.assertEqual(get_fn(gguf_ctx, gguf_key_id), v)
+
+    ggml.gguf_free(gguf_ctx)
+
 if __name__ == "__main__":
   unittest.main()
@@ -1,5 +1,5 @@
-import os, json, pathlib, zipfile, pickle, tarfile, struct
-from typing import Dict, Union, List, Optional, Any, Tuple
+import os, json, pathlib, zipfile, pickle, tarfile, struct, functools
+from typing import Dict, Union, List, Optional, Any, Tuple, Callable
 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
@@ -226,3 +226,64 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
       base_offset += 8 + lens[i]
     f.seek(rwd)
     return TorchPickle(f).load()
+
+def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int):
+  bc_dtype = { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type, None)
+  if bc_dtype is not None: return t[:bc_dtype.itemsize * n].bitcast(bc_dtype)
+
+  def q_to_uint8(t: Tensor, b: int) -> Tensor:
+    shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
+    return t.unsqueeze(-1).expand((*t.shape,8//b)).div(shift_tensor, upcast=False).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
+
+  blk_nel, blk_nb = { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }[ggml_type]
+  blocks = t[:(n//blk_nel)*blk_nb].reshape((-1, blk_nb))
+  if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
+  if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
+  if ggml_type == 3:
+    d, m = tuple(blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
+    return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
+  if ggml_type == 14:
+    xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
+    scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((blocks.shape[0], 16, 16)).reshape((-1, 256))
+    d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
+    return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
+
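For context on the block formats handled above: GGML type 2 is Q4_0 (blocks of 32 weights stored as an fp16 scale `d` plus 16 bytes packing 32 4-bit values, dequantized as `(q - 8) * d`), type 3 is Q4_1 (adds an fp16 offset `m`, giving `q * d + m`), type 8 is Q8_0 (fp16 scale plus 32 int8 values) and type 14 is Q6_K (256-value super-blocks with 6-bit values split across low and high bits plus 16 int8 sub-scales). A hedged pure-NumPy reference for a single Q4_0 block, handy for sanity-checking the tensor version above (this helper is illustrative and not part of the diff):

```python
import numpy as np

def dequant_q4_0_block(block: bytes) -> np.ndarray:
  # Q4_0 block: 18 bytes = fp16 scale d, then 16 bytes packing 32 4-bit quants
  assert len(block) == 18
  d = np.frombuffer(block[:2], dtype=np.float16).astype(np.float32)[0]
  qs = np.frombuffer(block[2:], dtype=np.uint8)
  lo, hi = qs & 0x0F, qs >> 4            # low nibbles are weights 0..15, high nibbles 16..31
  q = np.concatenate([lo, hi]).astype(np.int8)
  return (q - 8) * d                     # 32 dequantized float32 weights
```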
+def load_gguf(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
+  """
+  Loads a gguf file from a tensor.
+
+  ```python
+  fn = "Meta-Llama-3-8B-Instruct.Q4_0.gguf"
+  gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
+  kv_data, state_dict = load_gguf(gguf_tensor)
+  ```
+  """
+  if tensor.dtype != dtypes.uint8 or len(tensor.shape) != 1: raise ValueError("GGUF tensor must be 1d and of dtype uint8!")
+  pos, read_buffer, rb_start, kv_data, state_dict = 0, memoryview(bytes()), 0, {}, {}
+  def read_bytes(n: int):
+    nonlocal pos, read_buffer, rb_start
+    if rb_start + len(read_buffer) < pos + n: rb_start, read_buffer = pos, tensor[pos:(pos+max(n, 1000_000))].data()
+    return read_buffer[pos-rb_start:(pos:=pos+n)-rb_start]
+  def read_unpack(fmt: str, n: int): return struct.unpack(fmt, read_bytes(n))[0]
+  def read_str(): return str(read_bytes(read_uint64()), "utf-8")
+  def read_arr():
+    reader, n = readers[read_int32()], read_uint64()
+    return [ reader() for _ in range(n) ]
+
+  readers: Dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t, f, nb in [ (0,"c",1),
+    (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
+  read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11]
+
+  magic, version, n_tensors, n_kv = read_bytes(4), read_int32(), read_int64(), read_int64()
+  if magic != b"GGUF" or version not in [2, 3]: raise ValueError("Invalid GGUF format!")
+  for _ in range(n_kv):
+    k, typ = read_str(), read_int32()
+    kv_data[k] = readers[typ]()
+
+  t_infos = [ (read_str(), tuple(read_uint64() for _ in range(read_uint32())), read_int32(), read_uint64()) for _ in range(n_tensors) ]
+  alignment = kv_data.get("general.alignment", 32)
+  data_start = pos = pos + (alignment - pos % alignment if pos % alignment != 0 else 0)
+
+  for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))
+
+  return kv_data, state_dict
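Two details of the format handled at the end of `load_gguf`: tensor data begins at the next `general.alignment` boundary (32 bytes unless the metadata overrides it) after the tensor-info table, and each tensor's dims are stored innermost-first, hence the final `reshape(*reversed(dims))`. A tiny worked example of the alignment arithmetic with made-up numbers:

```python
pos, alignment = 115, 32  # hypothetical end-of-header offset
data_start = pos + (alignment - pos % alignment if pos % alignment != 0 else 0)
assert data_start == 128  # rounded up to the next 32-byte boundary
```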