mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
does this fix the dtype test? (#13779)
* does this fix the dtype test? * simpler
This commit is contained in:
@@ -60,7 +60,7 @@ def universal_test(a, b, dtype, op):
|
||||
ta, tb = Tensor([a], dtype=dtype), Tensor([b], dtype=dtype)
|
||||
tensor_value = (op[0](ta, tb)).numpy()
|
||||
numpy_value = op[1](ta.numpy(), tb.numpy())
|
||||
if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value)
|
||||
if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value.item())
|
||||
if dtype in dtypes.floats:
|
||||
atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype, (1e-10, 1e-7))
|
||||
np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol)
|
||||
@@ -77,8 +77,8 @@ def universal_test_unary(a, dtype, op):
|
||||
numpy_value = op[1](ta.numpy())
|
||||
if dtype in dtypes.fp8s:
|
||||
# cuda cast f32 inf to f8 MAX, amd cast it to nan(E4M3)/inf(E5M2)
|
||||
if math.isinf(numpy_value): return
|
||||
numpy_value = truncate[dtype](numpy_value)
|
||||
if math.isinf(numpy_value.item()): return
|
||||
numpy_value = truncate[dtype](numpy_value.item())
|
||||
if dtype in dtypes.floats:
|
||||
atol, rtol = { dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2),
|
||||
dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2: (1.0, 5e-1)}.get(dtype, (1e-6, 1e-5))
|
||||
|
||||
Reference in New Issue
Block a user