fix test_dequantization_mxfp4 (#12123)

* fix test_dequantization_mxfp4

* assert_allclose

* rtol
This commit is contained in:
chenyu
2025-09-11 14:22:06 -04:00
committed by GitHub
parent 520e2e0727
commit 3a83b56da5

View File

@@ -59,7 +59,6 @@ class TestGGUF(unittest.TestCase):
def test_dequantization_q4_1(self): self._test_dequantization(ggml.GGML_TYPE_Q4_1)
def test_dequantization_q8_0(self): self._test_dequantization(ggml.GGML_TYPE_Q8_0)
def test_dequantization_q6_k(self): self._test_dequantization(ggml.GGML_TYPE_Q6_K)
@unittest.expectedFailure #does not work
def test_dequantization_mxfp4(self):
MXFP4 = 39
@@ -68,7 +67,7 @@ class TestGGUF(unittest.TestCase):
return np.array([E] + packed, dtype=np.uint8)
def decode(code, E):
sign = -1.0 if code * 0b1000 else 1.0
sign = -1.0 if (code & 0b1000) else 1.0
exp = (code >> 1) & 0b11
mant = code & 0b1
val = (1.0 + 0.5 * mant) * np.exp2(exp - 1) if exp else 0.5 * mant
@@ -84,7 +83,8 @@ class TestGGUF(unittest.TestCase):
expected.extend(decode(c, E) for c in codes)
tensor = Tensor(np.concatenate(blocks))
out = ggml_data_to_tensor(tensor, len(expected), MXFP4)
self.assertListEqual(out.numpy().tolist(), np.array(expected, dtype=np.float32).tolist())
# TODO: should this be exact equal? somehow failed on CI
np.testing.assert_allclose(out.numpy(), expected, atol=0.0, rtol=1e-6)
def test_expected_failure_unknown_type(self):
with self.assertRaises(ValueError):