[TEST] Added matmul config for testing (#1758)

This commit is contained in:
Zahi Moudallal
2023-06-22 13:31:37 -07:00
committed by GitHub
parent 8d566e4196
commit ca4f242c9b

View File

@@ -67,6 +67,7 @@ def f8_to_f16(x):
(128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE, DTYPE),
(128, 256, 64, 1, 8, 3, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
],
# n-stage
@@ -131,6 +132,9 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
th_c = torch.matmul(a, b)
try:
tt_c = triton.ops.matmul(a, b)
torch.testing.assert_allclose(th_c, tt_c, atol=1e-2, rtol=0)
atol, rtol = 1e-2, 0
if ADTYPE == torch.bfloat16 or BDTYPE == torch.bfloat16:
atol, rtol = 3.5e-2, 0
torch.testing.assert_allclose(th_c, tt_c, atol=atol, rtol=rtol)
except triton.OutOfResources as e:
pytest.skip(str(e))