mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
set correct arch info for unit test (#370)
* set correct arch info for unit test * address review comments
This commit is contained in:
@@ -2525,6 +2525,12 @@ class SharedLayout:
|
||||
return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
|
||||
|
||||
|
||||
def get_gpu_name():
|
||||
capabilities = triton.compiler.compiler.get_architecture_descriptor(None)
|
||||
gpu_name = capabilities[1].split(':')[0]
|
||||
return gpu_name
|
||||
|
||||
|
||||
@pytest.mark.parametrize("vec_size", [2, 4])
|
||||
@pytest.mark.parametrize("swizzle", [True, False])
|
||||
@pytest.mark.parametrize("transposeA", [True, False])
|
||||
@@ -2534,6 +2540,9 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB):
|
||||
if transposeA and not transposeB:
|
||||
pytest.skip()
|
||||
|
||||
if triton.language.semantic.gpu_matrix_core_version() == 0:
|
||||
pytest.skip("mfma is not available on hardware")
|
||||
|
||||
# source code for following ttgir:
|
||||
# @triton.jit
|
||||
# def kernel(X, Y, Z):
|
||||
@@ -2617,7 +2626,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
|
||||
f.write(ir)
|
||||
f.flush()
|
||||
arch_triple = "amdgcn-amd-amdhsa"
|
||||
arch_name = "gfx90a"
|
||||
arch_name = get_gpu_name()
|
||||
features = ""
|
||||
warp_size = 64
|
||||
capabilities = [arch_triple, arch_name, features, warp_size]
|
||||
|
||||
Reference in New Issue
Block a user