From ba0c844a83ccc2e0c3e52fc619946cf2272ef9a4 Mon Sep 17 00:00:00 2001 From: ignaciosica Date: Sat, 21 Dec 2024 13:32:26 -0300 Subject: [PATCH] special tol when f16 and bf16 are tc input dtypes (#8183) --- test/test_linearizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index c7c441ebca..0774f11018 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -40,8 +40,8 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi assert len([uop for uop in k.uops if uop.op is Ops.WMMA]) > 0, "tensor core not triggered" assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included" np_c = np_a @ np_b - if dtype_out == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3 - elif dtype_out == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 1e-2 + if dtype_in == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3 + elif dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 1e-2 else: tc_atol, tc_rtol = 5e-3, 1e-4 np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)