mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[CI] [Dot] Reduced test suite (#302)
Use the upstream list of tests for the dot op on machines with no MFMA support. This is needed to reduce the time required for PR testing.
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# flake8: noqa: F821,F841
|
||||
import triton.language.semantic
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
@@ -1204,30 +1205,38 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
# test dot
|
||||
# ---------------
|
||||
|
||||
# @pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
|
||||
# [(*shape, 4, False, False, epilogue, allow_tf32, dtype)
|
||||
# for shape in [(64, 64, 64), (16, 16, 16)]
|
||||
# for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
||||
# for allow_tf32 in [True, False]
|
||||
# for dtype in ['float16', 'float32']
|
||||
# if not (allow_tf32 and (dtype in ['float16']))] +
|
||||
# [(*shape_nw, col_a, col_b, 'none', allow_tf32, dtype)
|
||||
# for shape_nw in [[128, 256, 32, 8],
|
||||
# [128, 16, 32, 4],
|
||||
# [32, 128, 64, 4],
|
||||
# [128, 128, 64, 4],
|
||||
# [64, 128, 128, 4],
|
||||
# [32, 128, 64, 2],
|
||||
# [128, 128, 64, 2],
|
||||
# [64, 128, 128, 2]]
|
||||
# for allow_tf32 in [True]
|
||||
# for col_a in [True, False]
|
||||
# for col_b in [True, False]
|
||||
# for dtype in ['int8', 'float16', 'float32']])
|
||||
|
||||
|
||||
# MFMA Test Dot tests
|
||||
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype",
|
||||
# FMA Test Dot tests
|
||||
[(*shape, 4, False, False, epilogue, allow_tf32, in_dtype, out_dtype)
|
||||
for shape in [(64, 64, 64), (16, 16, 16)]
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
||||
for allow_tf32 in [True, False]
|
||||
for in_dtype, out_dtype in [('float16', 'float16'),
|
||||
('float16', 'float32'),
|
||||
('float32', 'float32')]
|
||||
if not (allow_tf32 and (in_dtype in ['float16']))] +
|
||||
|
||||
[(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype)
|
||||
for shape_nw in [[128, 256, 32, 8],
|
||||
[128, 16, 32, 4],
|
||||
[32, 128, 64, 4],
|
||||
[128, 128, 64, 4],
|
||||
[64, 128, 128, 4],
|
||||
[32, 128, 64, 2],
|
||||
[64, 64, 32, 4],
|
||||
[32, 32, 128, 16],
|
||||
[128, 128, 64, 2],
|
||||
[64, 128, 128, 2]]
|
||||
for allow_tf32 in [True]
|
||||
for col_a in [True, False]
|
||||
for col_b in [True, False]
|
||||
for in_dtype, out_dtype in [('int8', 'int8'),
|
||||
('float16', 'float16'),
|
||||
('float16', 'float32'),
|
||||
('float32', 'float32')]]
|
||||
if not triton.language.semantic.gpu_has_mfma() else
|
||||
# MFMA Test Dot tests
|
||||
[(*shape, 2, False, False, epilogue, allow_tf32, in_dtype, out_dtype)
|
||||
for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax']
|
||||
@@ -1273,6 +1282,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
if torch.version.hip is not None:
|
||||
# set capability to large number to jump over check below
|
||||
# check are not relevant to amd gpu, left them for smaller diff between test_core.py and test_core_amd.py tests
|
||||
if (M, N, K) == (128, 256, 32):
|
||||
pytest.skip("Out of resources")
|
||||
capability = (100, 100)
|
||||
if out_dtype is None:
|
||||
if in_dtype in float_dtypes:
|
||||
|
||||
Reference in New Issue
Block a user