special tol when f16 and bf16 are tc input dtypes (#8183)

This commit is contained in:
ignaciosica
2024-12-21 13:32:26 -03:00
committed by GitHub
parent 3f83748661
commit ba0c844a83

View File

@@ -40,8 +40,8 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
 assert len([uop for uop in k.uops if uop.op is Ops.WMMA]) > 0, "tensor core not triggered"
 assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
 np_c = np_a @ np_b
-if dtype_out == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
-elif dtype_out == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 1e-2
+if dtype_in == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
+elif dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 1e-2
 else: tc_atol, tc_rtol = 5e-3, 1e-4
 np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)