[TESTS] replace deprecated torch.testing.assert_allclose (#2250)
Prior to this PR, matmul on sm_89 (RTX 4070) (`test/unit/operators/test_matmul.py::test_op`) failed because the chosen atol/rtol were too strict. To avoid picking the strictness ourselves, and to get better dtype-based defaults, use the non-deprecated torch testing util. See: https://github.com/pytorch/pytorch/issues/61844. Replaces: https://github.com/openai/triton/pull/2242
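The behavioral differences behind the extra keyword arguments in the diff below can be summarized with a small sketch (the tensors here are illustrative, not taken from the test suite): `torch.testing.assert_close` derives default rtol/atol from the dtype when none are given, requires matching dtypes unless `check_dtype=False`, and treats NaN as mismatching unless `equal_nan=True`, whereas the deprecated `assert_allclose` compared NaNs as equal by default.

```python
import torch

# Sketch of the migration pattern applied throughout this diff
# (tensors are illustrative, not from the tests).
expected = torch.randn(128, 128, dtype=torch.float32)
actual = expected + 1e-6 * torch.randn_like(expected)

# Old, deprecated in favor of assert_close:
#   torch.testing.assert_allclose(actual, expected, rtol=1e-2, atol=1e-2)

# New: dtype-aware default tolerances when rtol/atol are omitted.
torch.testing.assert_close(actual, expected)

# Mismatched dtypes now fail unless explicitly allowed, hence the
# check_dtype=False added at call sites comparing fp16/fp32 results below.
torch.testing.assert_close(actual.half(), expected, atol=1e-2, rtol=0, check_dtype=False)

# NaNs compare unequal by default, hence equal_nan=True where NaNs are expected.
nan = torch.tensor(float("nan"))
torch.testing.assert_close(nan, nan, equal_nan=True)
```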
@@ -227,4 +227,4 @@ def test_iv_dependent_matmul(type):
                  b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1),
                  BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
                  type=type, num_stages=num_stages)
-    torch.testing.assert_allclose(torch_output, triton_output, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2)

@@ -149,7 +149,7 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO
     static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)

     th_c = torch.matmul(a, b)
-    torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
+    torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False)


 @triton.jit

@@ -300,7 +300,7 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
                            enable_warp_specialization=True)

     th_c = torch.matmul(a, b)
-    torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
+    torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False)


 @triton.jit

@@ -456,7 +456,7 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
                            enable_warp_specialization=True)

     th_c = torch.matmul(a, b)
-    torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
+    torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False)


 @triton.jit

@@ -99,4 +99,4 @@ def test_block_ptr_matmul_no_scf(shape, num_warps):
                   BLOCK_M=m, BLOCK_N=n, BLOCK_K=k,
                   num_warps=num_warps)
     golden = torch.matmul(a, b)
-    torch.testing.assert_allclose(c, golden)
+    torch.testing.assert_close(c, golden, check_dtype=False)

@@ -881,7 +881,7 @@ def test_abs_fp8(in_dtype, device):
     f32_tensor = convert_float_to_float32(f8_tensor, in_dtype)
     expect = f32_tensor.abs()
     actual_f8 = convert_float_to_float32(out_f8, in_dtype)
-    torch.testing.assert_allclose(actual_f8, expect)
+    torch.testing.assert_close(actual_f8, expect, equal_nan=True)


 # ----------------

@@ -2594,7 +2594,7 @@ def test_masked_load(dtype_str, size, size_diff, num_ctas, device):

     reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
     # print((output - reference_out).nonzero())
-    torch.testing.assert_allclose(output, reference_out)
+    torch.testing.assert_close(output, reference_out)

 # Testing masked loads with an intermate copy to shared memory run.

@@ -2649,7 +2649,7 @@ def test_masked_load_shared_memory(dtype, device):
                   M=M, N=N, K=K)

     reference_out = torch.matmul(in1, in2)
-    torch.testing.assert_allclose(out, reference_out, atol=1e-2, rtol=0)
+    torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0)


 @pytest.mark.parametrize("cache", ["", ".ca", ".cg"])

@@ -86,9 +86,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
     da_tri = a_tri.grad
     db_tri = b_tri.grad
     # compare
-    torch.testing.assert_allclose(c_ref, c_tri)
-    torch.testing.assert_allclose(da_ref, da_tri)
-    torch.testing.assert_allclose(db_ref, db_tri)
+    torch.testing.assert_close(c_ref, c_tri)
+    torch.testing.assert_close(da_ref, da_tri)
+    torch.testing.assert_close(db_ref, db_tri)


 configs = [

@@ -138,8 +138,8 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
     out_tri.backward(dout_tri)
     da_tri = a_tri.grad
     # compare
-    torch.testing.assert_allclose(out_tri, out_ref)
-    torch.testing.assert_allclose(da_tri, da_ref)
+    torch.testing.assert_close(out_tri, out_ref, equal_nan=True)
+    torch.testing.assert_close(da_tri, da_ref, equal_nan=True)


 @pytest.mark.parametrize("block", [16, 32, 64])

@@ -195,9 +195,9 @@ def test_attention_fwd_bwd(

     # comparison
     # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
-    torch.testing.assert_allclose(loss, torch_loss, atol=1e-3, rtol=0)
+    torch.testing.assert_close(loss, torch_loss, atol=1e-3, rtol=0)
     for g1, g2 in zip(grads, torch_grads):
-        torch.testing.assert_allclose(g1, g2)
+        torch.testing.assert_close(g1, g2)


 @pytest.mark.parametrize("block", [16, 32, 64])

@@ -25,7 +25,7 @@ def test_op(M, N, dtype, mode):
     tt_y = triton.ops.cross_entropy(x, idx)
     th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
     if mode == 'forward':
-        torch.testing.assert_allclose(th_y, tt_y)
+        torch.testing.assert_close(th_y, tt_y)
     # backward pass
     elif mode == 'backward':
         dy = torch.randn_like(tt_y)

@@ -37,4 +37,4 @@ def test_op(M, N, dtype, mode):
         th_y.backward(dy)
         th_dx = x.grad.clone()

-        torch.testing.assert_allclose(th_dx, tt_dx)
+        torch.testing.assert_close(th_dx, tt_dx)

@@ -51,7 +51,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
     tri_dq, q.grad = q.grad.clone(), None
     # compare
     atol = 1e-1 if dtype == torch.bfloat16 else 1e-2
-    torch.testing.assert_allclose(ref_out, tri_out, atol=atol, rtol=0)
-    torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0)
-    torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0)
-    torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0)
+    torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=0)
+    torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
+    torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
+    torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)

@@ -52,7 +52,7 @@ def test_normalization_with_remat():
     arg8_1 = torch.rand(64, device="cuda")
     arg9_1 = torch.rand(64, device="cuda")
     triton_[(512,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)
-    torch.testing.assert_allclose(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)
+    torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)


 def test_avg_pool_bw():

@@ -152,4 +152,4 @@ def test_avg_pool_bw():
     out_ref[:, :, 1:7, 0::7] = 2 / 3
     out_ref[:, :, 0::7, 1:7] = 2 / 3
     out_ref[:, :, 0::7, 0::7] = 4 / 9
-    torch.testing.assert_allclose(out, out_ref)
+    torch.testing.assert_close(out, out_ref)

@@ -177,6 +177,6 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
         if b_fp8:
             b = triton.reinterpret(b, getattr(tl, BDTYPE))
         tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32)
-        torch.testing.assert_allclose(th_c, tt_c, atol=0, rtol=0)
+        torch.testing.assert_close(th_c, tt_c)
     except triton.OutOfResources as e:
         pytest.skip(str(e))