atol in test_call_gemm (#14480)

flaky
This commit is contained in:
chenyu
2026-02-01 11:24:58 -05:00
committed by GitHub
parent 5705398a1f
commit 02afae04f4

View File

@@ -42,7 +42,7 @@ class TestCall(unittest.TestCase):
b = Tensor.randn(K, N)
Tensor.realize(a, b)
c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1))
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5)
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
@unittest.skip("needs GEMM on mixins")
def test_call_gemm_uop(self):
@@ -56,7 +56,7 @@ class TestCall(unittest.TestCase):
y = UOp.param(1, dtypes.float, shape=(K, N))
c = Tensor.call(a, b, fxn=x@y)
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5)
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
if __name__ == '__main__':
unittest.main()