mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix test_dequantization_mxfp4 (#12123)
* fix test_dequantization_mxfp4 * assert_allclose * rtol
This commit is contained in:
@@ -59,7 +59,6 @@ class TestGGUF(unittest.TestCase):
|
||||
def test_dequantization_q4_1(self): self._test_dequantization(ggml.GGML_TYPE_Q4_1)
|
||||
def test_dequantization_q8_0(self): self._test_dequantization(ggml.GGML_TYPE_Q8_0)
|
||||
def test_dequantization_q6_k(self): self._test_dequantization(ggml.GGML_TYPE_Q6_K)
|
||||
@unittest.expectedFailure #does not work
|
||||
def test_dequantization_mxfp4(self):
|
||||
MXFP4 = 39
|
||||
|
||||
@@ -68,7 +67,7 @@ class TestGGUF(unittest.TestCase):
|
||||
return np.array([E] + packed, dtype=np.uint8)
|
||||
|
||||
def decode(code, E):
|
||||
sign = -1.0 if code * 0b1000 else 1.0
|
||||
sign = -1.0 if (code & 0b1000) else 1.0
|
||||
exp = (code >> 1) & 0b11
|
||||
mant = code & 0b1
|
||||
val = (1.0 + 0.5 * mant) * np.exp2(exp - 1) if exp else 0.5 * mant
|
||||
@@ -84,7 +83,8 @@ class TestGGUF(unittest.TestCase):
|
||||
expected.extend(decode(c, E) for c in codes)
|
||||
tensor = Tensor(np.concatenate(blocks))
|
||||
out = ggml_data_to_tensor(tensor, len(expected), MXFP4)
|
||||
self.assertListEqual(out.numpy().tolist(), np.array(expected, dtype=np.float32).tolist())
|
||||
# TODO: should this be exact equal? somehow failed on CI
|
||||
np.testing.assert_allclose(out.numpy(), expected, atol=0.0, rtol=1e-6)
|
||||
|
||||
def test_expected_failure_unknown_type(self):
|
||||
with self.assertRaises(ValueError):
|
||||
|
||||
Reference in New Issue
Block a user