From 35e3983840ebf9d95a1dd412ff9ddb8e39f69997 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Wed, 8 Apr 2026 17:16:19 +0800
Subject: [PATCH] Add Q5_0, Q5_1, and bfloat16 GGUF types (#15644)

---
 test/unit/test_gguf.py | 33 ++++++++++++++++++++++-----------
 tinygrad/nn/state.py   | 23 +++++++++++++++++++----
 2 files changed, 41 insertions(+), 15 deletions(-)

diff --git a/test/unit/test_gguf.py b/test/unit/test_gguf.py
index 5b7f2f4de5..ee042b72e4 100644
--- a/test/unit/test_gguf.py
+++ b/test/unit/test_gguf.py
@@ -30,11 +30,15 @@ class TestGGUF(unittest.TestCase):
   def test_dequantization_q4_0(self): self._test_dequantization(GGMLQuantizationType.Q4_0)
   def test_dequantization_q4_1(self): self._test_dequantization(GGMLQuantizationType.Q4_1)
+  def test_dequantization_q5_0(self): self._test_dequantization(GGMLQuantizationType.Q5_0)
+  def test_dequantization_q5_1(self): self._test_dequantization(GGMLQuantizationType.Q5_1)
   def test_dequantization_q8_0(self): self._test_dequantization(GGMLQuantizationType.Q8_0)
   def test_dequantization_q4_k(self): self._test_dequantization(GGMLQuantizationType.Q4_K)
   def test_dequantization_q5_k(self): self._test_dequantization(GGMLQuantizationType.Q5_K)
   def test_dequantization_q6_k(self): self._test_dequantization(GGMLQuantizationType.Q6_K)
   def test_dequantization_mxfp4(self): self._test_dequantization(GGMLQuantizationType.MXFP4)
+  @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "Backend must support bfloat16")
+  def test_dequantization_bf16(self): self._test_dequantization(GGMLQuantizationType.BF16)
   def test_dequantization_mxfp4_old(self):
     def encode(nibbles, E):
       packed = [(low & 0xF) | ((high & 0xF) << 4) for low, high in zip(nibbles[:16], nibbles[16:])]
@@ -120,17 +124,21 @@ class TestGGUFGEMV(unittest.TestCase):
   def _test_gguf_gemv(self, qtype: GGMLQuantizationType):
     block_size, type_size = GGML_QUANT_SIZES[qtype]
-    rows, cols = 8192, 2048
+    rows, cols = (1024, 512) if qtype == GGMLQuantizationType.BF16 else (8192, 2048)
     n_blocks = rows * cols // block_size
     rng = np.random.default_rng(42)
-    # generate random quantized blocks with valid fp16 scale fields (random bytes can produce NaN scales)
-    q_data = rng.integers(0, 256, size=n_blocks * type_size, dtype=np.uint8).reshape(n_blocks, type_size)
-    scales = np.float16(rng.standard_normal(n_blocks * 4)).view(np.uint8).reshape(n_blocks, -1)
-    if qtype == GGMLQuantizationType.Q8_0: q_data[:, :2] = scales[:, :2] # d at offset 0
-    elif qtype in (GGMLQuantizationType.Q4_K, GGMLQuantizationType.Q5_K): q_data[:, :4] = scales[:, :4] # d, dmin at offset 0
-    elif qtype == GGMLQuantizationType.Q6_K: q_data[:, -2:] = scales[:, :2] # d at end
-    elif qtype == GGMLQuantizationType.MXFP4: q_data[:, 0] = rng.integers(120, 136, size=n_blocks, dtype=np.uint8) # constrain byte0
-    q_data = q_data.flatten()
+    if qtype == GGMLQuantizationType.BF16:
+      # bfloat16 is the top 16 bits of a float32, so truncating the mantissa yields valid bf16 bit patterns
+      q_data = (rng.standard_normal(rows * cols).astype(np.float32).view(np.uint32) >> 16).astype(np.uint16).view(np.uint8)
+    else:
+      # generate random quantized blocks with valid fp16 scale fields (random bytes can produce NaN scales)
+      q_data = rng.integers(0, 256, size=n_blocks * type_size, dtype=np.uint8).reshape(n_blocks, type_size)
+      scales = np.float16(rng.standard_normal(n_blocks * 4)).view(np.uint8).reshape(n_blocks, -1)
+      if qtype in (GGMLQuantizationType.Q5_0, GGMLQuantizationType.Q8_0): q_data[:, :2] = scales[:, :2] # d at offset 0
+      elif qtype in (GGMLQuantizationType.Q5_1, GGMLQuantizationType.Q4_K, GGMLQuantizationType.Q5_K): q_data[:, :4] = scales[:, :4] # d, m/dmin at offset 0
+      elif qtype == GGMLQuantizationType.Q6_K: q_data[:, -2:] = scales[:, :2] # d at end
+      elif qtype == GGMLQuantizationType.MXFP4: q_data[:, 0] = rng.integers(120, 136, size=n_blocks, dtype=np.uint8) # constrain byte0
+      q_data = q_data.flatten()
     ref = dequantize(q_data, qtype).reshape(rows, cols)
 
     # build a minimal gguf in memory: header + 1 tensor info + aligned data
@@ -148,15 +156,18 @@ class TestGGUFGEMV(unittest.TestCase):
     x = rng.standard_normal(cols).astype(np.float32)
     np.testing.assert_allclose((tensors["weight"] @ Tensor(x)).numpy(), ref @ x, atol=1e-2, rtol=1e-2)
 
-    # can only expect the weights to be identical if we really support float16 (ie. not decompositions)
-    if is_dtype_supported(dtypes.half): np.testing.assert_equal(tensors["weight"].numpy(), ref)
+    # can only expect the weights to be identical if we really support the dtype (ie. not decompositions); bf16 is already gated by skipUnless
+    if qtype == GGMLQuantizationType.BF16 or is_dtype_supported(dtypes.half): np.testing.assert_equal(tensors["weight"].numpy(), ref)
     assert np.isfinite(ref).all() and np.isfinite(tensors["weight"].numpy()).all(), f"{qtype.name} has NaN/Inf"
 
   def test_gguf_gemv_q8_0(self): self._test_gguf_gemv(GGMLQuantizationType.Q8_0)
+  def test_gguf_gemv_q5_0(self): self._test_gguf_gemv(GGMLQuantizationType.Q5_0)
+  def test_gguf_gemv_q5_1(self): self._test_gguf_gemv(GGMLQuantizationType.Q5_1)
   def test_gguf_gemv_q4_k(self): self._test_gguf_gemv(GGMLQuantizationType.Q4_K)
   def test_gguf_gemv_q5_k(self): self._test_gguf_gemv(GGMLQuantizationType.Q5_K)
   def test_gguf_gemv_q6_k(self): self._test_gguf_gemv(GGMLQuantizationType.Q6_K)
   def test_gguf_gemv_mxfp4(self): self._test_gguf_gemv(GGMLQuantizationType.MXFP4)
+  @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "Backend must support bfloat16")
+  def test_gguf_gemv_bf16(self): self._test_gguf_gemv(GGMLQuantizationType.BF16)
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index 5c83d6c3e7..fe8736c117 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -297,13 +297,19 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
   """
   Converts ggml tensor data to a tinygrad tensor.
 
-  Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16), int16 (id: 17), int32 (id: 18)
-  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q4_K (id: 12), Q5_K (id: 13), Q6_K (id: 14), MXFP4 (id: 39)
+  Supported native types: float32 (id: 0), float16 (id: 1), bfloat16 (id: 30),
+  int8 (id: 16), int16 (id: 17), int32 (id: 18)
+  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q5_0 (id: 6),
+  Q5_1 (id: 7), Q8_0 (id: 8), Q4_K (id: 12), Q5_K (id: 13),
+  Q6_K (id: 14), MXFP4 (id: 39)
   """
   # https://github.com/ggerganov/ggml/blob/323951f1bdcdfbd5b5ff3a9a7c3770e63b1a560e/include/ggml.h#L356
   # native types
-  if (dtype := { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type)) is not None:
+  if (dtype := {
+    0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8,
+    17: dtypes.int16, 18: dtypes.int32, 30: dtypes.bfloat16,
+  }.get(ggml_type)) is not None:
     return t[:dtype.itemsize * n].contiguous().bitcast(dtype)
 
   def q_to_uint8(t: Tensor, b: int) -> Tensor:
@@ -312,12 +318,21 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
     return t.unsqueeze(-1).expand((*t.shape,8//b)).idiv(shift_tensor).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
 
   # map to (number of elements, number of bytes)
-  if (nelements_nbytes := { 2:(32,18), 3:(32,20), 8:(32,34), 12:(256,144), 13:(256,176), 14:(256,210), 39:(32,17) }.get(ggml_type)) is not None:
+  if (nelements_nbytes := {
+    2:(32,18), 3:(32,20), 6:(32,22), 7:(32,24), 8:(32,34),
+    12:(256,144), 13:(256,176), 14:(256,210), 39:(32,17),
+  }.get(ggml_type)) is not None:
     blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1])).contiguous()
     if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
     if ggml_type == 3:
       d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
       return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
+    if ggml_type in (6, 7): # Q5_0 / Q5_1: 4-bit nibbles in qs plus a 5th bit per element in the 4-byte qh field
+      d = blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
+      qh_off = 2 if ggml_type == 6 else 4 # Q5_1 carries an fp16 m between d and qh
+      # extract qh bits and reorder from bit-position-major to byte-major, matching ggml's little-endian uint32 qh
+      qh = q_to_uint8(blocks[:,qh_off:qh_off+4], 1).reshape((-1, 8, 4)).transpose(-1, -2).flatten(-2).bitcast(dtypes.int8)
+      q = q_to_uint8(blocks[:,qh_off+4:], 4).bitcast(dtypes.int8) + qh * 16
+      return q * d + (blocks[:,2:4].bitcast(dtypes.float16).cast(dtypes.float32) if ggml_type == 7 else -16 * d) # Q5_1 adds m, Q5_0 recenters by -16
     if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
     # Q4_K: 256 elements per 144-byte block (d:2, dmin:2, scales:12, qs:128)
     # Q5_K: 256 elements per 176-byte block (d:2, dmin:2, scales:12, qh:32, qs:128)
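
Reviewer note, not part of the patch: for anyone auditing the `ggml_type in (6, 7)` branch above, here is a minimal NumPy sketch of the Q5_0/Q5_1 block layouts it decodes. The helper names (`q5_bits`, `dequant_q5_0`, `dequant_q5_1`) are hypothetical; the layouts follow the GGML block definitions (Q5_0: fp16 d, 4-byte qh, 16-byte qs for 32 elements; Q5_1 inserts an fp16 m after d). Element i takes its low 4 bits from the qs nibbles (elements 0-15 from the low nibbles, 16-31 from the high nibbles) and its 5th bit from bit i of the little-endian uint32 qh, which is exactly the reordering the reshape/transpose on qh performs.

    import numpy as np

    def q5_bits(qh_bytes: np.ndarray, qs: np.ndarray) -> np.ndarray:
      # combine the 4-bit nibbles with the per-element high bit from the 32-bit qh field
      qh = int.from_bytes(qh_bytes.tobytes(), "little")
      lo = np.concatenate([qs & 0xF, qs >> 4]).astype(np.int32)  # elements 0-15, then 16-31
      hi = np.array([(qh >> i) & 1 for i in range(32)], dtype=np.int32) << 4
      return lo | hi

    def dequant_q5_0(block: np.ndarray) -> np.ndarray:  # block: 22 uint8 bytes, 32 elements
      d = block[:2].view(np.float16).astype(np.float32)[0]
      return d * (q5_bits(block[2:6], block[6:]) - 16)  # d * (q5 - 16)

    def dequant_q5_1(block: np.ndarray) -> np.ndarray:  # block: 24 uint8 bytes, 32 elements
      d, m = block[:4].view(np.float16).astype(np.float32)
      return d * q5_bits(block[4:8], block[8:]) + m  # d * q5 + m

    # self-check: encode q5 values 0..31 with d=1.0; the high bits are set for elements 16-31
    qs = bytes((i & 0xF) | (((i + 16) & 0xF) << 4) for i in range(16))
    block = np.frombuffer(np.float16(1.0).tobytes() + (0xFFFF0000).to_bytes(4, "little") + qs, dtype=np.uint8)
    assert np.array_equal(dequant_q5_0(block), np.arange(32) - 16)

The BF16 additions need no such decode table: bfloat16 is the top half of a float32 bit pattern, which is both why `ggml_data_to_tensor` can handle id 30 with a plain bitcast and why the test can synthesize valid bf16 data with `.view(np.uint32) >> 16`.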