diff --git a/python/test/regression/test_functional_regressions.py b/python/test/regression/test_functional_regressions.py
index 684cbfb4d..b873db7a3 100644
--- a/python/test/regression/test_functional_regressions.py
+++ b/python/test/regression/test_functional_regressions.py
@@ -227,4 +227,4 @@ def test_iv_dependent_matmul(type):
                   b.stride(0), b.stride(1),
                   triton_output.stride(0), triton_output.stride(1),
                   BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, type=type, num_stages=num_stages)
-    torch.testing.assert_allclose(torch_output, triton_output, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2)
diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py
index fd7c14e6c..d66c3b795 100644
--- a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py
+++ b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py
@@ -149,7 +149,7 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO
     static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K,
                                           stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1),
                                           BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
     th_c = torch.matmul(a, b)
-    torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
+    torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False)
 
 
 @triton.jit
@@ -300,7 +300,7 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
                          enable_warp_specialization=True)
 
     th_c = torch.matmul(a, b)
-    torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
+    torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False)
 
 
 @triton.jit
@@ -456,7 +456,7 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
                          enable_warp_specialization=True)
 
     th_c = torch.matmul(a, b)
-    torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
+    torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False)
 
 
 @triton.jit
diff --git a/python/test/unit/language/test_block_pointer.py b/python/test/unit/language/test_block_pointer.py
index 147249076..3cc4bdced 100644
--- a/python/test/unit/language/test_block_pointer.py
+++ b/python/test/unit/language/test_block_pointer.py
@@ -99,4 +99,4 @@ def test_block_ptr_matmul_no_scf(shape, num_warps):
                           BLOCK_M=m, BLOCK_N=n, BLOCK_K=k,
                           num_warps=num_warps)
     golden = torch.matmul(a, b)
-    torch.testing.assert_allclose(c, golden)
+    torch.testing.assert_close(c, golden, check_dtype=False)
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index dbeca94c4..f962facb4 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -881,7 +881,7 @@ def test_abs_fp8(in_dtype, device):
     f32_tensor = convert_float_to_float32(f8_tensor, in_dtype)
     expect = f32_tensor.abs()
     actual_f8 = convert_float_to_float32(out_f8, in_dtype)
-    torch.testing.assert_allclose(actual_f8, expect)
+    torch.testing.assert_close(actual_f8, expect, equal_nan=True)
 
 
 # ----------------
@@ -2594,7 +2594,7 @@ def test_masked_load(dtype_str, size, size_diff, num_ctas, device):
     reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
     # print((output - reference_out).nonzero())
-    torch.testing.assert_allclose(output, reference_out)
+    torch.testing.assert_close(output, reference_out)
 
 
 # Testing masked loads with an intermate copy to shared memory run.
@@ -2649,7 +2649,7 @@ def test_masked_load_shared_memory(dtype, device):
                      M=M, N=N, K=K)
 
     reference_out = torch.matmul(in1, in2)
-    torch.testing.assert_allclose(out, reference_out, atol=1e-2, rtol=0)
+    torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0)
 
 
 @pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py
index 5f94cd8b3..7e6f820a3 100644
--- a/python/test/unit/operators/test_blocksparse.py
+++ b/python/test/unit/operators/test_blocksparse.py
@@ -86,9 +86,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
     da_tri = a_tri.grad
     db_tri = b_tri.grad
     # compare
-    torch.testing.assert_allclose(c_ref, c_tri)
-    torch.testing.assert_allclose(da_ref, da_tri)
-    torch.testing.assert_allclose(db_ref, db_tri)
+    torch.testing.assert_close(c_ref, c_tri)
+    torch.testing.assert_close(da_ref, da_tri)
+    torch.testing.assert_close(db_ref, db_tri)
 
 
 configs = [
@@ -138,8 +138,8 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
     out_tri.backward(dout_tri)
     da_tri = a_tri.grad
     # compare
-    torch.testing.assert_allclose(out_tri, out_ref)
-    torch.testing.assert_allclose(da_tri, da_ref)
+    torch.testing.assert_close(out_tri, out_ref, equal_nan=True)
+    torch.testing.assert_close(da_tri, da_ref, equal_nan=True)
 
 
 @pytest.mark.parametrize("block", [16, 32, 64])
@@ -195,9 +195,9 @@ def test_attention_fwd_bwd(
     # comparison
     # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
-    torch.testing.assert_allclose(loss, torch_loss, atol=1e-3, rtol=0)
+    torch.testing.assert_close(loss, torch_loss, atol=1e-3, rtol=0)
     for g1, g2 in zip(grads, torch_grads):
-        torch.testing.assert_allclose(g1, g2)
+        torch.testing.assert_close(g1, g2)
 
 
 @pytest.mark.parametrize("block", [16, 32, 64])
diff --git a/python/test/unit/operators/test_cross_entropy.py b/python/test/unit/operators/test_cross_entropy.py
index f4e40d3a6..be59fc42a 100644
--- a/python/test/unit/operators/test_cross_entropy.py
+++ b/python/test/unit/operators/test_cross_entropy.py
@@ -25,7 +25,7 @@ def test_op(M, N, dtype, mode):
     tt_y = triton.ops.cross_entropy(x, idx)
     th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
    if mode == 'forward':
-        torch.testing.assert_allclose(th_y, tt_y)
+        torch.testing.assert_close(th_y, tt_y)
     # backward pass
     elif mode == 'backward':
         dy = torch.randn_like(tt_y)
@@ -37,4 +37,4 @@ def test_op(M, N, dtype, mode):
         th_y.backward(dy)
         th_dx = x.grad.clone()
-        torch.testing.assert_allclose(th_dx, tt_dx)
+        torch.testing.assert_close(th_dx, tt_dx)
diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py
index 4bacf53b7..b6f74f2fc 100644
--- a/python/test/unit/operators/test_flash_attention.py
+++ b/python/test/unit/operators/test_flash_attention.py
@@ -51,7 +51,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
     tri_dq, q.grad = q.grad.clone(), None
     # compare
     atol = 1e-1 if dtype == torch.bfloat16 else 1e-2
-    torch.testing.assert_allclose(ref_out, tri_out, atol=atol, rtol=0)
-    torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0)
-    torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0)
-    torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0)
+    torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=0)
+    torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
+    torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
+    torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
diff --git a/python/test/unit/operators/test_inductor.py b/python/test/unit/operators/test_inductor.py
index f7e2ce2aa..fa157d2c9 100644
--- a/python/test/unit/operators/test_inductor.py
+++ b/python/test/unit/operators/test_inductor.py
@@ -52,7 +52,7 @@ def test_normalization_with_remat():
     arg8_1 = torch.rand(64, device="cuda")
     arg9_1 = torch.rand(64, device="cuda")
     triton_[(512,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)
-    torch.testing.assert_allclose(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)
+    torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)
 
 
 def test_avg_pool_bw():
@@ -152,4 +152,4 @@ def test_avg_pool_bw():
     out_ref[:, :, 1:7, 0::7] = 2 / 3
     out_ref[:, :, 0::7, 1:7] = 2 / 3
     out_ref[:, :, 0::7, 0::7] = 4 / 9
-    torch.testing.assert_allclose(out, out_ref)
+    torch.testing.assert_close(out, out_ref)
diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py
index a7afa02f1..19b5e0f05 100644
--- a/python/test/unit/operators/test_matmul.py
+++ b/python/test/unit/operators/test_matmul.py
@@ -177,6 +177,6 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
         if b_fp8:
             b = triton.reinterpret(b, getattr(tl, BDTYPE))
         tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32)
-        torch.testing.assert_allclose(th_c, tt_c, atol=0, rtol=0)
+        torch.testing.assert_close(th_c, tt_c)
     except triton.OutOfResources as e:
         pytest.skip(str(e))
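Background on the migration above, not part of the patch itself: the deprecated torch.testing.assert_allclose compared values after dtype promotion and treated NaN as equal by default, while its replacement torch.testing.assert_close checks dtypes and rejects NaN mismatches unless told otherwise. That is why some call sites gain check_dtype=False or equal_nan=True. A minimal sketch of the default-behavior gap, using made-up tensors purely for illustration:

    # Illustration only; the tensors here are hypothetical and not from the tests.
    import torch

    a = torch.tensor([1.0, float("nan")], dtype=torch.float32)
    b = a.to(torch.float16)

    # Under assert_allclose's defaults this comparison passed silently.
    # assert_close needs the dtype check and NaN handling relaxed explicitly,
    # mirroring the options added to the updated tests.
    torch.testing.assert_close(a, b, check_dtype=False, equal_nan=True)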