GGUF support (#7046)

* basic loader, untested

* testing

* remove utils import in test

* q8_0

* q4_1

* end to end testing

* minor cleanup

* fix casting

* moved to state

* move tests

* move dequant to fn

* fix lint elif

* remove gguf from extra

* fix dict union

* q6_k simpler

* naming and spacing

* gpt2-gguf example

* cleanup

* move gguf example

* minor cleanup

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
leopf authored on 2024-10-21 10:15:34 +02:00 · committed by GitHub
parent 17e7d8f10e · commit b6d9b276bb
4 changed files with 188 additions and 6 deletions
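In short, this commit adds a `load_gguf` loader and a `ggml_data_to_tensor` dequantization helper to `tinygrad.nn.state`, a GGUF path for the GPT-2 example, and ggml-backed tests. A minimal usage sketch, mirroring the docstring example in the `nn.state` diff below (the file name here is a placeholder):

```python
# Sketch of the API this commit introduces; "model.Q4_0.gguf" is a placeholder path.
import os
from tinygrad import Tensor, Device, dtypes
from tinygrad.nn.state import load_gguf

fn = "model.Q4_0.gguf"
# map the file as a 1-D uint8 disk tensor, then move it to the default device
gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
kv_data, state_dict = load_gguf(gguf_tensor)  # metadata key/value dict, name -> dequantized Tensor
```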


@@ -1,13 +1,15 @@
#!/usr/bin/env python3
import os
from typing import Optional, Union
import argparse
import numpy as np
import tiktoken
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp
from tinygrad.helpers import Timing, DEBUG, JIT, getenv, fetch, colored, trange
from tinygrad.nn import Embedding, Linear, LayerNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from tinygrad.nn.state import load_gguf, torch_load, load_state_dict, get_state_dict
MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
HALF = getenv("HALF")
@@ -143,6 +145,34 @@ class GPT2:
    return GPT2(model, tokenizer)

  @staticmethod
  def build_gguf(model_size: str):
    q_type = model_size[len("gpt2_gguf_"):].upper()
    fn = fetch(f"https://huggingface.co/PrunaAI/gpt2-GGUF-smashed/resolve/main/gpt2.{q_type}.gguf?download=true")
    gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
    kv_data, state_dict = load_gguf(gguf_tensor)
    gpt2_params = {
      "dim": kv_data["gpt2.embedding_length"], "n_heads": kv_data["gpt2.attention.head_count"],
      "n_layers": kv_data["gpt2.block_count"], "norm_eps": kv_data["gpt2.attention.layer_norm_epsilon"],
      "vocab_size": 50257, "max_seq_len": kv_data["gpt2.context_length"],
    }
    def _remap_gguf_key(key: str):
      replaces = [
        ("blk.", "h."), (".attn_qkv.bias", ".attn.c_attn.bias"), (".attn_qkv.weight", ".attn.c_attn.weight"),
        (".ffn_norm.bias", ".ln_2.bias"), (".ffn_norm.weight", ".ln_2.weight"), (".attn_norm.bias", ".ln_1.bias"),
        (".attn_norm.weight", ".ln_1.weight"), (".attn_output.bias", ".attn.c_proj.bias"), (".attn_output.weight", ".attn.c_proj.weight"),
        (".ffn_up.bias", ".mlp.c_fc.bias"), (".ffn_up.weight", ".mlp.c_fc.weight"), (".ffn_down.bias", ".mlp.c_proj.bias"),
        (".ffn_down.weight", ".mlp.c_proj.weight"), ("token_embd.weight", "wte.weight"), ("output.weight", "lm_head.weight"),
        ("output_norm.bias", "ln_f.bias"), ("output_norm.weight", "ln_f.weight"), ("position_embd.weight", "wpe.weight"),
      ]
      for ostr, ns in replaces: key = key.replace(ostr, ns)
      return key
    state_dict = { _remap_gguf_key(k): v for k, v in state_dict.items() }
    model = Transformer(**gpt2_params)
    load_state_dict(model, state_dict)
    return GPT2(model, tiktoken.get_encoding("gpt2"))

  def __init__(self, model, tokenizer):
    self.model = model
    self.tokenizer = tokenizer
@@ -191,7 +221,7 @@ if __name__ == "__main__":
  np.random.seed(args.seed)
  print(f"using {args.model_size}")
  gpt2 = GPT2.build(args.model_size)
  gpt2 = GPT2.build_gguf(args.model_size) if args.model_size.startswith("gpt2_gguf_") else GPT2.build(args.model_size)
  if args.benchmark != -1:
    gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
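For context, the GPT-2 example now routes any model size beginning with `gpt2_gguf_` through `build_gguf`, upper-casing the rest of the name into the quantization suffix. A hedged sketch of calling it directly, based only on what the diff above shows:

```python
# "gpt2_gguf_q4_1" -> fetches gpt2.Q4_1.gguf from the PrunaAI/gpt2-GGUF-smashed repo,
# dequantizes it via load_gguf, remaps the GGUF tensor names onto the example's
# Transformer naming, and returns a GPT2 instance ready for the existing generation loop.
gpt2 = GPT2.build_gguf("gpt2_gguf_q4_1")
```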


@@ -57,6 +57,7 @@ setup(name='tinygrad',
"hypothesis",
"nibabel",
"bottle",
"ggml-python"
],
'docs': [
"mkdocs",


@@ -1,11 +1,11 @@
import os
import pathlib, tempfile, unittest
import tarfile
import tarfile, ggml, ctypes
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DType
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract, ggml_data_to_tensor, load_gguf
from tinygrad.helpers import Timing, fetch, temp, CI
from test.helpers import is_dtype_supported
@@ -445,5 +445,95 @@ class TestPathTensor(unittest.TestCase):
    self.assertEqual(t_cpu.device, "CPU")
    np.testing.assert_array_equal(t_cpu.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))

ggml_test_block_count = 4
ggml_type_to_np_dtype = {
  ggml.GGML_TYPE_F16: np.float16, ggml.GGML_TYPE_F32: np.float32, ggml.GGML_TYPE_F64: np.float64,
  ggml.GGML_TYPE_I8: np.int8, ggml.GGML_TYPE_I16: np.int16, ggml.GGML_TYPE_I32: np.int32, ggml.GGML_TYPE_I64: np.int64,
}
np_dtype_to_ctype = { np.float16: ctypes.c_uint16 }
gguf_val_getters = [
  ggml.gguf_get_val_u8, ggml.gguf_get_val_i8, ggml.gguf_get_val_u16, ggml.gguf_get_val_i16,
  ggml.gguf_get_val_u32, ggml.gguf_get_val_i32, ggml.gguf_get_val_f32, ggml.gguf_get_val_bool,
  lambda *args: ggml.gguf_get_val_str(*args).decode("utf-8"), None,
  ggml.gguf_get_val_u64, ggml.gguf_get_val_i64, ggml.gguf_get_val_f64,
]

def ggml_tensor_to_numpy(tensor: ggml.ggml_tensor_p):
  ctx: ggml.ggml_context_p | None = None
  ggml_type, n_dims, n_els = tensor.contents.type, ggml.ggml_n_dims(tensor), ggml.ggml_nelements(tensor)
  shape = tuple(reversed(tensor.contents.ne[:n_dims]))
  if ggml_type not in ggml_type_to_np_dtype:
    ctx = ggml.ggml_init(ggml.ggml_init_params(mem_size=n_els * 5 + 500, mem_buffer=None))
    ntensor = ggml.ggml_new_tensor(ctx, ggml.GGML_TYPE_F32, n_dims, tensor.contents.ne)
    type_traits = ggml.ggml_internal_get_type_traits(ggml_type)
    type_traits.to_float(ggml.ggml_get_data(tensor), ggml.ggml_get_data_f32(ntensor), n_els)
    tensor, ggml_type = ntensor, ggml.GGML_TYPE_F32
  np_type = ggml_type_to_np_dtype[ggml_type]
  ctypes_type = np_dtype_to_ctype.get(np_type, None) or np.ctypeslib.as_ctypes_type(np_type)
  data = ggml.ggml_get_data(tensor)
  if data is None: raise ValueError("tensor data is None")
  arr = (ctypes_type * ggml.ggml_nelements(tensor)).from_address(data)
  strides = tuple(reversed(tensor.contents.nb[:n_dims]))
  output = np.ctypeslib.as_array(arr)
  output.dtype = np_type
  return np.lib.stride_tricks.as_strided(output, shape=shape, strides=strides), ctx

@unittest.skipIf(any(not is_dtype_supported(t) for t in [ dtypes.uint8, dtypes.half ]), "Backend must support uint8 and half")
class TestGGUF(unittest.TestCase):
  def setUp(self) -> None:
    params = ggml.ggml_init_params(mem_size=0, mem_buffer=None, no_alloc=False)
    self.ctx = ctypes.cast(ggml.ggml_init(params), ctypes.POINTER(ctypes.c_void_p))
  def tearDown(self) -> None: ggml.ggml_free(self.ctx)

  def test_load_tinyllama_q8_0(self): self._test_load_gguf("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q8_0.gguf?download=true")
  def test_load_tinyllama_q4_0(self): self._test_load_gguf("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf?download=true")
  def test_load_gpt2_q4_1(self): self._test_load_gguf("https://huggingface.co/PrunaAI/gpt2-GGUF-smashed/resolve/main/gpt2.Q4_1.gguf?download=true")
  def test_load_sample_q6_k(self): self._test_load_gguf("https://huggingface.co/Isotr0py/test-gguf-sample/resolve/main/Quant_Q6_K_1024.gguf?download=true")

  def test_dequantization_q4_0(self): self._test_dequantization(ggml.GGML_TYPE_Q4_0)
  def test_dequantization_q4_1(self): self._test_dequantization(ggml.GGML_TYPE_Q4_1)
  def test_dequantization_q8_0(self): self._test_dequantization(ggml.GGML_TYPE_Q8_0)
  def test_dequantization_q6_k(self): self._test_dequantization(ggml.GGML_TYPE_Q6_K)

  def _test_dequantization(self, ttype: int):
    type_traits = ggml.ggml_internal_get_type_traits(ttype)
    n_el, n_bytes = ggml_test_block_count * type_traits.blck_size, ggml_test_block_count * type_traits.type_size
    data_in = (np.random.random((n_el,)).astype(np.float32) * 100 - 50).ctypes.data_as(ctypes.POINTER(ctypes.c_float))
    c_q_data, c_dq_data = (ctypes.c_char * n_bytes)(0), (ctypes.c_float * n_el)(0)
    type_traits.from_float(data_in, c_q_data, n_el)
    type_traits.to_float(c_q_data, c_dq_data, n_el)
    q_tensor = Tensor(np.frombuffer(c_q_data, dtype=np.uint8, count=n_bytes))
    dq_tensor = ggml_data_to_tensor(q_tensor, n_el, ttype).reshape(n_el)
    np.testing.assert_equal(dq_tensor.numpy(), np.frombuffer(c_dq_data, dtype=np.float32))

  def _test_load_gguf(self, url: str):
    fp = fetch(url)
    model_size = os.stat(fp).st_size
    gguf_tensor = Tensor.empty(model_size, dtype=dtypes.uint8, device=f"disk:{fp}").to(Device.DEFAULT)
    kv_data, tensors = load_gguf(gguf_tensor)

    gguf_params = ggml.gguf_init_params(ctx=self.ctx, no_alloc=False)
    gguf_ctx = ggml.gguf_init_from_file(str(fp).encode("utf8"), gguf_params)
    param_ctx = gguf_params.ctx.contents.value

    for ggml_tensor_idx in range(ggml.gguf_get_n_tensors(gguf_ctx)):
      tensor_name = ggml.gguf_get_tensor_name(gguf_ctx, ggml_tensor_idx)
      ggml_tensor = ggml.ggml_get_tensor(param_ctx, tensor_name)
      ggml_tensor_numpy, temp_ctx = ggml_tensor_to_numpy(ggml_tensor)
      tensor = tensors.get(tensor_name.decode("utf-8"))
      np.testing.assert_equal(tensor.numpy(), ggml_tensor_numpy)
      if temp_ctx is not None: ggml.ggml_free(temp_ctx)

    for gguf_key_id in range(ggml.gguf_get_n_kv(gguf_ctx)):
      v = kv_data[ggml.gguf_get_key(gguf_ctx, gguf_key_id).decode("utf-8")]
      v_type = ggml.gguf_get_kv_type(gguf_ctx, gguf_key_id)
      if (get_fn := gguf_val_getters[v_type]) is not None: self.assertEqual(get_fn(gguf_ctx, gguf_key_id), v)

    ggml.gguf_free(gguf_ctx)

if __name__ == "__main__":
  unittest.main()
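The `_test_dequantization` round trip above quantizes random floats with ggml's `from_float`, dequantizes the raw block bytes with both ggml's `to_float` and tinygrad's `ggml_data_to_tensor`, and asserts the two results are equal. As a reference for what those block bytes look like, here is a NumPy-only sketch of Q8_0 dequantization (an illustration written for this annotation, not code from the commit): each 34-byte block is a float16 scale `d` followed by 32 int8 quants, and every element is `d * q`, matching the `ggml_type == 8` branch of `ggml_data_to_tensor`.

```python
import numpy as np

def dequantize_q8_0_reference(raw: bytes, n: int) -> np.ndarray:
  """Reference Q8_0 dequantizer: 34-byte blocks = fp16 scale + 32 int8 quants."""
  blocks = np.frombuffer(raw, dtype=np.uint8)[:(n // 32) * 34].reshape(-1, 34)
  d = blocks[:, :2].copy().view(np.float16).astype(np.float32)  # (n_blocks, 1) scales
  q = blocks[:, 2:].copy().view(np.int8).astype(np.float32)     # (n_blocks, 32) quants
  return (d * q).reshape(-1)[:n]
```

This should agree elementwise with `ggml_data_to_tensor(q_tensor, n, 8).reshape(n).numpy()` for the same bytes, which is exactly the equality the Q8_0 test checks against ggml.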


@@ -1,5 +1,5 @@
import os, json, pathlib, zipfile, pickle, tarfile, struct
from typing import Dict, Union, List, Optional, Any, Tuple
import os, json, pathlib, zipfile, pickle, tarfile, struct, functools
from typing import Dict, Union, List, Optional, Any, Tuple, Callable
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
@@ -226,3 +226,64 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
      base_offset += 8 + lens[i]
    f.seek(rwd)
  return TorchPickle(f).load()

def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int):
  bc_dtype = { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type, None)
  if bc_dtype is not None: return t[:bc_dtype.itemsize * n].bitcast(bc_dtype)

  def q_to_uint8(t: Tensor, b: int) -> Tensor:
    shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
    return t.unsqueeze(-1).expand((*t.shape,8//b)).div(shift_tensor, upcast=False).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)

  blk_nel, blk_nb = { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }[ggml_type]
  blocks = t[:(n//blk_nel)*blk_nb].reshape((-1, blk_nb))
  if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
  if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
  if ggml_type == 3:
    d, m = tuple(blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
    return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
  if ggml_type == 14:
    xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
    scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((blocks.shape[0], 16, 16)).reshape((-1, 256))
    d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
    return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales

def load_gguf(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
  """
  Loads a gguf file from a tensor.

  ```python
  fn = "Meta-Llama-3-8B-Instruct.Q4_0.gguf"
  gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
  kv_data, state_dict = load_gguf(gguf_tensor)
  ```
  """
  if tensor.dtype != dtypes.uint8 or len(tensor.shape) != 1: raise ValueError("GGUF tensor must be 1d and of dtype uint8!")
  pos, read_buffer, rb_start, kv_data, state_dict = 0, memoryview(bytes()), 0, {}, {}
  def read_bytes(n: int):
    nonlocal pos, read_buffer, rb_start
    if rb_start + len(read_buffer) < pos + n: rb_start, read_buffer = pos, tensor[pos:(pos+max(n, 1000_000))].data()
    return read_buffer[pos-rb_start:(pos:=pos+n)-rb_start]
  def read_unpack(fmt: str, n: int): return struct.unpack(fmt, read_bytes(n))[0]
  def read_str(): return str(read_bytes(read_uint64()), "utf-8")
  def read_arr():
    reader, n = readers[read_int32()], read_uint64()
    return [ reader() for _ in range(n) ]

  readers: Dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t, f, nb in [ (0,"c",1),
    (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
  read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11]

  magic, version, n_tensors, n_kv = read_bytes(4), read_int32(), read_int64(), read_int64()
  if magic != b"GGUF" or version not in [2, 3]: raise ValueError("Invalid GGUF format!")
  for _ in range(n_kv):
    k, typ = read_str(), read_int32()
    kv_data[k] = readers[typ]()

  t_infos = [ (read_str(), tuple(read_uint64() for _ in range(read_uint32())), read_int32(), read_uint64()) for _ in range(n_tensors) ]
  alignment = kv_data.get("general.alignment", 32)
  data_start = pos = pos + (alignment - pos % alignment if pos % alignment != 0 else 0)
  for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))
  return kv_data, state_dict
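For the packed 4-bit formats the interesting part is `q_to_uint8`, which splits each byte into its nibbles and orders all low nibbles before all high nibbles; that matches ggml's Q4_0 layout, where element `j` lives in the low nibble of byte `j` and element `j+16` in the high nibble. A NumPy sketch of the same Q4_0 (`ggml_type == 2`) math, added here purely as an illustration and not part of the commit:

```python
import numpy as np

def dequantize_q4_0_reference(raw: bytes, n: int) -> np.ndarray:
  """Reference Q4_0 dequantizer: 18-byte blocks = fp16 scale d + 16 bytes packing
  32 4-bit quants; value = (q - 8) * d, low nibbles first, then high nibbles."""
  blocks = np.frombuffer(raw, dtype=np.uint8)[:(n // 32) * 18].reshape(-1, 18)
  d = blocks[:, :2].copy().view(np.float16).astype(np.float32)        # (n_blocks, 1) scales
  qs = blocks[:, 2:]                                                  # (n_blocks, 16) packed bytes
  q = np.concatenate([qs & 0x0F, qs >> 4], axis=1).astype(np.int8)    # (n_blocks, 32) quants
  return (d * (q - 8)).reshape(-1)[:n]
```

Up to the final reshape, this should agree with `ggml_data_to_tensor(..., n, 2)`, which is the property the Q4_0 dequantization test above checks against ggml's own `to_float`.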