ROCM IFU: fix test_dot_mfma_vector_load test

Fix for the previous commit.
This commit is contained in:
Alexander Efimov
2023-11-02 17:40:37 +00:00
committed by Jason Furmanek
parent 8bc417b9b7
commit aefc94bd25

View File

@@ -2718,11 +2718,6 @@ class SharedLayout:
return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
def get_gpu_name():
    """Return the GPU architecture name reported by the Triton compiler.

    Looks up the architecture descriptor for the default device and
    returns its "gfx_arch" entry (an AMD GFX architecture string).
    """
    # None selects the default architecture descriptor.
    descriptor = triton.compiler.compiler.get_architecture_descriptor(None)
    return descriptor["gfx_arch"]
@pytest.mark.parametrize("vec_size", [2, 4])
@pytest.mark.parametrize("swizzle", [True, False])
@pytest.mark.parametrize("transposeA", [True, False])
@@ -2795,7 +2790,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%21 = tt.load %13 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked>
%22 = triton_gpu.convert_layout %21 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #shared2>
%23 = triton_gpu.convert_layout %22 : (tensor<32x32xf16, #shared2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}>>
%24 = tt.dot %20, %23, %cst {allowTF32 = false} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}>> -> tensor<32x32xf32, #mfma>
%24 = tt.dot %20, %23, %cst {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}>> -> tensor<32x32xf32, #mfma>
%25 = triton_gpu.convert_layout %24 : (tensor<32x32xf32, #mfma>) -> tensor<32x32xf32, #blocked>
%26 = arith.truncf %25 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked>
tt.store %17, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf16, #blocked>
@@ -2817,11 +2812,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
arch_triple = "amdgcn-amd-amdhsa"
arch_name = get_gpu_name()
features = ""
warp_size = 64
capabilities = [arch_triple, arch_name, features, warp_size]
backend = triton.common.backend.get_backend("hip")
capabilities = backend.get_architecture_descriptor()
kernel = triton.compile(f.name, device_type="hip", cc=capabilities)
import triton.language.semantic as sem