special tol when f16 and bf16 are tc input dtypes (#8183)
@@ -40,8 +40,8 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
   assert len([uop for uop in k.uops if uop.op is Ops.WMMA]) > 0, "tensor core not triggered"
   assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
   np_c = np_a @ np_b
-  if dtype_out == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
-  elif dtype_out == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 1e-2
+  if dtype_in == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
+  elif dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 1e-2
   else: tc_atol, tc_rtol = 5e-3, 1e-4
   np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)
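Why the branch keys off dtype_in rather than dtype_out: once the operands have been rounded to f16 or bf16, the numpy reference np_a @ np_b already differs from the exact product by roughly the input dtype's machine epsilon, no matter how precise the accumulator or output dtype is, so the tolerance floor is set by the inputs. The sketch below illustrates this; tc_tolerances is a hypothetical helper that mirrors the branch above, not tinygrad code, and since numpy has no native bfloat16 only the f16 case is measured.

import numpy as np

def tc_tolerances(dtype_in):
  # f16 keeps 10 mantissa bits, bf16 only 7, so both need looser
  # tolerances than full-precision inputs (bf16 loosest of all).
  if dtype_in == np.float16: return 1e-2, 1e-3   # atol, rtol
  return 5e-3, 1e-4                              # e.g. float32 inputs

n = 64
rng = np.random.default_rng(0)
a64 = rng.standard_normal((n, n))
b64 = rng.standard_normal((n, n))
a16, b16 = a64.astype(np.float16), b64.astype(np.float16)

# Error from rounding the *inputs* alone: the matmul itself runs in
# float64 here, so any output dtype would still see at least this error.
err = np.abs(a16.astype(np.float64) @ b16.astype(np.float64) - a64 @ b64).max()
atol, rtol = tc_tolerances(np.float16)
print(f"max error from f16 inputs: {err:.2e} (test atol {atol:.0e})")

Running this shows that the input rounding alone produces errors well above the 5e-3/1e-4 tolerances used for full-precision inputs, which is why the commit loosens the bounds whenever f16 or bf16 is the tensor-core input dtype, regardless of dtype_out.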