does this fix the dtype test? (#13779)

* does this fix the dtype test?

* simpler
This commit is contained in:
George Hotz
2025-12-20 17:31:46 -04:00
committed by GitHub
parent 5228f7bd06
commit 59c02dd87f

View File

@@ -60,7 +60,7 @@ def universal_test(a, b, dtype, op):
ta, tb = Tensor([a], dtype=dtype), Tensor([b], dtype=dtype)
tensor_value = (op[0](ta, tb)).numpy()
numpy_value = op[1](ta.numpy(), tb.numpy())
if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value)
if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value.item())
if dtype in dtypes.floats:
atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype, (1e-10, 1e-7))
np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol)
@@ -77,8 +77,8 @@ def universal_test_unary(a, dtype, op):
numpy_value = op[1](ta.numpy())
if dtype in dtypes.fp8s:
# cuda cast f32 inf to f8 MAX, amd cast it to nan(E4M3)/inf(E5M2)
if math.isinf(numpy_value): return
numpy_value = truncate[dtype](numpy_value)
if math.isinf(numpy_value.item()): return
numpy_value = truncate[dtype](numpy_value.item())
if dtype in dtypes.floats:
atol, rtol = { dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2),
dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2: (1.0, 5e-1)}.get(dtype, (1e-6, 1e-5))