[TESTS] replace deprecated torch.testing.assert_allclose (#2250)

Prior to this PR, the matmul test
(`test/unit/operators/test_matmul.py::test_op`) failed on sm_89 (RTX 4070)
because the hard-coded atol/rtol were too strict.

To avoid choosing tolerances ourselves, and to get better dtype-based
defaults, switch to the non-deprecated torch testing utility,
`torch.testing.assert_close`.

See: https://github.com/pytorch/pytorch/issues/61844

Replace: https://github.com/openai/triton/pull/2242
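
For reviewers, a minimal sketch of the migration pattern this diff applies
(illustrative only, not part of the commit; the tensors and shapes below are
made up):

    import torch

    # Illustrative stand-ins: in the real tests, `actual` comes from a Triton kernel.
    a = torch.randn(64, 64)
    b = torch.randn(64, 64)
    expected = torch.matmul(a, b)
    actual = torch.matmul(a, b)

    # Deprecated (see pytorch/pytorch#61844); tolerances had to be hand-picked:
    # torch.testing.assert_allclose(expected, actual, rtol=1e-2, atol=1e-2)

    # Replacement: when rtol/atol are omitted, defaults are derived from the dtype.
    torch.testing.assert_close(expected, actual)

    # assert_close is stricter by default than assert_allclose, which is why some
    # call sites in this diff add keyword arguments to keep the old behaviour:
    #   check_dtype=False  -> allow expected/actual dtypes to differ
    #   equal_nan=True     -> treat NaNs at matching positions as equal
    torch.testing.assert_close(expected, actual.to(torch.float64),
                               check_dtype=False, equal_nan=True)
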
Author: jon-chuang
Date: 2023-09-12 03:31:17 +08:00
Committed by: GitHub
Parent: 28d4c3bdb4
Commit: 5231d57c71
9 changed files with 24 additions and 24 deletions

@@ -227,4 +227,4 @@ def test_iv_dependent_matmul(type):
 b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1),
 BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
 type=type, num_stages=num_stages)
-torch.testing.assert_allclose(torch_output, triton_output, rtol=1e-2, atol=1e-2)
+torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2)

@@ -149,7 +149,7 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO
 static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
 th_c = torch.matmul(a, b)
-torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
+torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False)
 @triton.jit
@@ -300,7 +300,7 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
 enable_warp_specialization=True)
 th_c = torch.matmul(a, b)
-torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
+torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False)
 @triton.jit
@@ -456,7 +456,7 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
 enable_warp_specialization=True)
 th_c = torch.matmul(a, b)
-torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
+torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False)
 @triton.jit

@@ -99,4 +99,4 @@ def test_block_ptr_matmul_no_scf(shape, num_warps):
 BLOCK_M=m, BLOCK_N=n, BLOCK_K=k,
 num_warps=num_warps)
 golden = torch.matmul(a, b)
-torch.testing.assert_allclose(c, golden)
+torch.testing.assert_close(c, golden, check_dtype=False)

@@ -881,7 +881,7 @@ def test_abs_fp8(in_dtype, device):
 f32_tensor = convert_float_to_float32(f8_tensor, in_dtype)
 expect = f32_tensor.abs()
 actual_f8 = convert_float_to_float32(out_f8, in_dtype)
-torch.testing.assert_allclose(actual_f8, expect)
+torch.testing.assert_close(actual_f8, expect, equal_nan=True)
 # ----------------
@@ -2594,7 +2594,7 @@ def test_masked_load(dtype_str, size, size_diff, num_ctas, device):
 reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
 # print((output - reference_out).nonzero())
-torch.testing.assert_allclose(output, reference_out)
+torch.testing.assert_close(output, reference_out)
 # Testing masked loads with an intermate copy to shared memory run.
@@ -2649,7 +2649,7 @@ def test_masked_load_shared_memory(dtype, device):
 M=M, N=N, K=K)
 reference_out = torch.matmul(in1, in2)
-torch.testing.assert_allclose(out, reference_out, atol=1e-2, rtol=0)
+torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0)
 @pytest.mark.parametrize("cache", ["", ".ca", ".cg"])

@@ -86,9 +86,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
 da_tri = a_tri.grad
 db_tri = b_tri.grad
 # compare
-torch.testing.assert_allclose(c_ref, c_tri)
-torch.testing.assert_allclose(da_ref, da_tri)
-torch.testing.assert_allclose(db_ref, db_tri)
+torch.testing.assert_close(c_ref, c_tri)
+torch.testing.assert_close(da_ref, da_tri)
+torch.testing.assert_close(db_ref, db_tri)
 configs = [
@@ -138,8 +138,8 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
 out_tri.backward(dout_tri)
 da_tri = a_tri.grad
 # compare
-torch.testing.assert_allclose(out_tri, out_ref)
-torch.testing.assert_allclose(da_tri, da_ref)
+torch.testing.assert_close(out_tri, out_ref, equal_nan=True)
+torch.testing.assert_close(da_tri, da_ref, equal_nan=True)
 @pytest.mark.parametrize("block", [16, 32, 64])
@@ -195,9 +195,9 @@ def test_attention_fwd_bwd(
 # comparison
 # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
-torch.testing.assert_allclose(loss, torch_loss, atol=1e-3, rtol=0)
+torch.testing.assert_close(loss, torch_loss, atol=1e-3, rtol=0)
 for g1, g2 in zip(grads, torch_grads):
-torch.testing.assert_allclose(g1, g2)
+torch.testing.assert_close(g1, g2)
 @pytest.mark.parametrize("block", [16, 32, 64])

@@ -25,7 +25,7 @@ def test_op(M, N, dtype, mode):
 tt_y = triton.ops.cross_entropy(x, idx)
 th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
 if mode == 'forward':
-torch.testing.assert_allclose(th_y, tt_y)
+torch.testing.assert_close(th_y, tt_y)
 # backward pass
 elif mode == 'backward':
 dy = torch.randn_like(tt_y)
@@ -37,4 +37,4 @@ def test_op(M, N, dtype, mode):
 th_y.backward(dy)
 th_dx = x.grad.clone()
-torch.testing.assert_allclose(th_dx, tt_dx)
+torch.testing.assert_close(th_dx, tt_dx)

@@ -51,7 +51,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
 tri_dq, q.grad = q.grad.clone(), None
 # compare
 atol = 1e-1 if dtype == torch.bfloat16 else 1e-2
-torch.testing.assert_allclose(ref_out, tri_out, atol=atol, rtol=0)
-torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0)
-torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0)
-torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0)
+torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=0)
+torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
+torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
+torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)

@@ -52,7 +52,7 @@ def test_normalization_with_remat():
 arg8_1 = torch.rand(64, device="cuda")
 arg9_1 = torch.rand(64, device="cuda")
 triton_[(512,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)
-torch.testing.assert_allclose(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)
+torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)
 def test_avg_pool_bw():
@@ -152,4 +152,4 @@ def test_avg_pool_bw():
 out_ref[:, :, 1:7, 0::7] = 2 / 3
 out_ref[:, :, 0::7, 1:7] = 2 / 3
 out_ref[:, :, 0::7, 0::7] = 4 / 9
-torch.testing.assert_allclose(out, out_ref)
+torch.testing.assert_close(out, out_ref)

@@ -177,6 +177,6 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
 if b_fp8:
 b = triton.reinterpret(b, getattr(tl, BDTYPE))
 tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32)
-torch.testing.assert_allclose(th_c, tt_c, atol=0, rtol=0)
+torch.testing.assert_close(th_c, tt_c)
 except triton.OutOfResources as e:
 pytest.skip(str(e))