mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Add Q5_0, Q5_1, and bfloat16 GGUF types (#15644)
This commit is contained in:
@@ -30,11 +30,15 @@ class TestGGUF(unittest.TestCase):
|
||||
|
||||
def test_dequantization_q4_0(self): self._test_dequantization(GGMLQuantizationType.Q4_0)
|
||||
def test_dequantization_q4_1(self): self._test_dequantization(GGMLQuantizationType.Q4_1)
|
||||
def test_dequantization_q5_0(self): self._test_dequantization(GGMLQuantizationType.Q5_0)
|
||||
def test_dequantization_q5_1(self): self._test_dequantization(GGMLQuantizationType.Q5_1)
|
||||
def test_dequantization_q8_0(self): self._test_dequantization(GGMLQuantizationType.Q8_0)
|
||||
def test_dequantization_q4_k(self): self._test_dequantization(GGMLQuantizationType.Q4_K)
|
||||
def test_dequantization_q5_k(self): self._test_dequantization(GGMLQuantizationType.Q5_K)
|
||||
def test_dequantization_q6_k(self): self._test_dequantization(GGMLQuantizationType.Q6_K)
|
||||
def test_dequantization_mxfp4(self): self._test_dequantization(GGMLQuantizationType.MXFP4)
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "Backend must support bfloat16")
|
||||
def test_dequantization_bf16(self): self._test_dequantization(GGMLQuantizationType.BF16)
|
||||
def test_dequantization_mxfp4_old(self):
|
||||
def encode(nibbles, E):
|
||||
packed = [(low & 0xF) | ((high & 0xF) << 4) for low, high in zip(nibbles[:16], nibbles[16:])]
|
||||
@@ -120,17 +124,21 @@ class TestGGUF(unittest.TestCase):
|
||||
class TestGGUFGEMV(unittest.TestCase):
|
||||
def _test_gguf_gemv(self, qtype: GGMLQuantizationType):
|
||||
block_size, type_size = GGML_QUANT_SIZES[qtype]
|
||||
rows, cols = 8192, 2048
|
||||
rows, cols = (1024, 512) if qtype == GGMLQuantizationType.BF16 else (8192, 2048)
|
||||
n_blocks = rows * cols // block_size
|
||||
rng = np.random.default_rng(42)
|
||||
# generate random quantized blocks with valid fp16 scale fields (random bytes can produce NaN scales)
|
||||
q_data = rng.integers(0, 256, size=n_blocks * type_size, dtype=np.uint8).reshape(n_blocks, type_size)
|
||||
scales = np.float16(rng.standard_normal(n_blocks * 4)).view(np.uint8).reshape(n_blocks, -1)
|
||||
if qtype == GGMLQuantizationType.Q8_0: q_data[:, :2] = scales[:, :2] # d at offset 0
|
||||
elif qtype in (GGMLQuantizationType.Q4_K, GGMLQuantizationType.Q5_K): q_data[:, :4] = scales[:, :4] # d, dmin at offset 0
|
||||
elif qtype == GGMLQuantizationType.Q6_K: q_data[:, -2:] = scales[:, :2] # d at end
|
||||
elif qtype == GGMLQuantizationType.MXFP4: q_data[:, 0] = rng.integers(120, 136, size=n_blocks, dtype=np.uint8) # constrain byte0
|
||||
q_data = q_data.flatten()
|
||||
if qtype == GGMLQuantizationType.BF16:
|
||||
q_data = (rng.standard_normal(rows * cols).astype(np.float32).view(np.uint32) >> 16).astype(np.uint16).view(np.uint8)
|
||||
else:
|
||||
# generate random quantized blocks with valid fp16 scale fields (random bytes can produce NaN scales)
|
||||
q_data = rng.integers(0, 256, size=n_blocks * type_size, dtype=np.uint8).reshape(n_blocks, type_size)
|
||||
scales = np.float16(rng.standard_normal(n_blocks * 4)).view(np.uint8).reshape(n_blocks, -1)
|
||||
if qtype in (GGMLQuantizationType.Q5_0, GGMLQuantizationType.Q8_0): q_data[:, :2] = scales[:, :2] # d at offset 0
|
||||
elif qtype in (GGMLQuantizationType.Q5_1, GGMLQuantizationType.Q4_K, GGMLQuantizationType.Q5_K):
|
||||
q_data[:, :4] = scales[:, :4] # d, m/dmin at offset 0
|
||||
elif qtype == GGMLQuantizationType.Q6_K: q_data[:, -2:] = scales[:, :2] # d at end
|
||||
elif qtype == GGMLQuantizationType.MXFP4: q_data[:, 0] = rng.integers(120, 136, size=n_blocks, dtype=np.uint8) # constrain byte0
|
||||
q_data = q_data.flatten()
|
||||
ref = dequantize(q_data, qtype).reshape(rows, cols)
|
||||
|
||||
# build a minimal gguf in memory: header + 1 tensor info + aligned data
|
||||
@@ -148,15 +156,18 @@ class TestGGUFGEMV(unittest.TestCase):
|
||||
|
||||
x = rng.standard_normal(cols).astype(np.float32)
|
||||
np.testing.assert_allclose((tensors["weight"] @ Tensor(x)).numpy(), ref @ x, atol=1e-2, rtol=1e-2)
|
||||
# can only expect the weights to be identical if we really support float16 (ie. not decompositions)
|
||||
if is_dtype_supported(dtypes.half): np.testing.assert_equal(tensors["weight"].numpy(), ref)
|
||||
if qtype == GGMLQuantizationType.BF16 or is_dtype_supported(dtypes.half): np.testing.assert_equal(tensors["weight"].numpy(), ref)
|
||||
assert np.isfinite(ref).all() and np.isfinite(tensors["weight"].numpy()).all(), f"{qtype.name} has NaN/Inf"
|
||||
|
||||
def test_gguf_gemv_q8_0(self): self._test_gguf_gemv(GGMLQuantizationType.Q8_0)
|
||||
def test_gguf_gemv_q5_0(self): self._test_gguf_gemv(GGMLQuantizationType.Q5_0)
|
||||
def test_gguf_gemv_q5_1(self): self._test_gguf_gemv(GGMLQuantizationType.Q5_1)
|
||||
def test_gguf_gemv_q4_k(self): self._test_gguf_gemv(GGMLQuantizationType.Q4_K)
|
||||
def test_gguf_gemv_q5_k(self): self._test_gguf_gemv(GGMLQuantizationType.Q5_K)
|
||||
def test_gguf_gemv_q6_k(self): self._test_gguf_gemv(GGMLQuantizationType.Q6_K)
|
||||
def test_gguf_gemv_mxfp4(self): self._test_gguf_gemv(GGMLQuantizationType.MXFP4)
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "Backend must support bfloat16")
|
||||
def test_gguf_gemv_bf16(self): self._test_gguf_gemv(GGMLQuantizationType.BF16)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -297,13 +297,19 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
|
||||
"""
|
||||
Converts ggml tensor data to a tinygrad tensor.
|
||||
|
||||
Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16), int16 (id: 17), int32 (id: 18)
|
||||
Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q4_K (id: 12), Q5_K (id: 13), Q6_K (id: 14), MXFP4 (id: 39)
|
||||
Supported native types: float32 (id: 0), float16 (id: 1), bfloat16 (id: 30),
|
||||
int8 (id: 16), int16 (id: 17), int32 (id: 18)
|
||||
Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q5_0 (id: 6),
|
||||
Q5_1 (id: 7), Q8_0 (id: 8), Q4_K (id: 12), Q5_K (id: 13),
|
||||
Q6_K (id: 14), MXFP4 (id: 39)
|
||||
"""
|
||||
# https://github.com/ggerganov/ggml/blob/323951f1bdcdfbd5b5ff3a9a7c3770e63b1a560e/include/ggml.h#L356
|
||||
|
||||
# native types
|
||||
if (dtype := { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type)) is not None:
|
||||
if (dtype := {
|
||||
0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8,
|
||||
17: dtypes.int16, 18: dtypes.int32, 30: dtypes.bfloat16,
|
||||
}.get(ggml_type)) is not None:
|
||||
return t[:dtype.itemsize * n].contiguous().bitcast(dtype)
|
||||
|
||||
def q_to_uint8(t: Tensor, b: int) -> Tensor:
|
||||
@@ -312,12 +318,21 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
|
||||
return t.unsqueeze(-1).expand((*t.shape,8//b)).idiv(shift_tensor).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
|
||||
|
||||
# map to (number of elements, number of bytes)
|
||||
if (nelements_nbytes := { 2:(32,18), 3:(32,20), 8:(32,34), 12:(256,144), 13:(256,176), 14:(256,210), 39:(32,17) }.get(ggml_type)) is not None:
|
||||
if (nelements_nbytes := {
|
||||
2:(32,18), 3:(32,20), 6:(32,22), 7:(32,24), 8:(32,34),
|
||||
12:(256,144), 13:(256,176), 14:(256,210), 39:(32,17),
|
||||
}.get(ggml_type)) is not None:
|
||||
blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1])).contiguous()
|
||||
if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
|
||||
if ggml_type == 3:
|
||||
d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
|
||||
return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
|
||||
if ggml_type in (6, 7):
|
||||
d = blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
|
||||
qh_off = 2 if ggml_type == 6 else 4
|
||||
qh = q_to_uint8(blocks[:,qh_off:qh_off+4], 1).reshape((-1, 8, 4)).transpose(-1, -2).flatten(-2).bitcast(dtypes.int8)
|
||||
q = q_to_uint8(blocks[:,qh_off+4:], 4).bitcast(dtypes.int8) + qh * 16
|
||||
return q * d + (blocks[:,2:4].bitcast(dtypes.float16).cast(dtypes.float32) if ggml_type == 7 else -16 * d)
|
||||
if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
|
||||
# Q4_K: 256 elements per 144-byte block (d:2, dmin:2, scales:12, qs:128)
|
||||
# Q5_K: 256 elements per 176-byte block (d:2, dmin:2, scales:12, qh:32, qs:128)
|
||||
|
||||
Reference in New Issue
Block a user