mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TEST] Added matmul config for testing (#1758)
This commit is contained in:
@@ -67,6 +67,7 @@ def f8_to_f16(x):
|
||||
(128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE, DTYPE),
|
||||
(128, 256, 64, 1, 8, 3, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
|
||||
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
|
||||
],
|
||||
# n-stage
|
||||
@@ -131,6 +132,9 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
|
||||
th_c = torch.matmul(a, b)
|
||||
try:
|
||||
tt_c = triton.ops.matmul(a, b)
|
||||
torch.testing.assert_allclose(th_c, tt_c, atol=1e-2, rtol=0)
|
||||
atol, rtol = 1e-2, 0
|
||||
if ADTYPE == torch.bfloat16 or BDTYPE == torch.bfloat16:
|
||||
atol, rtol = 3.5e-2, 0
|
||||
torch.testing.assert_allclose(th_c, tt_c, atol=atol, rtol=rtol)
|
||||
except triton.OutOfResources as e:
|
||||
pytest.skip(str(e))
|
||||
|
||||
Reference in New Issue
Block a user