// Source: mirror of https://github.com/ROCm/ROCm.git (synced 2026-04-05 03:01:17 -04:00)
// Commit context: tt.dot was changed to take an initial accumulator, plus a flag
// (maxNumImpreciseAcc) letting the compiler accumulate in lower precision than the
// output type. On Hopper this flag is on by default, which allows accumulating with
// lower precision; it only affects Hopper fp8 dot.
// RUN: triton-opt %s -split-input-file -verify-diagnostics

// Invalid case: operand A is f32 while operand B is f16. The verifier requires
// both dot operands to have element types of the same bit width.
#mma0 = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[1,1]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>

module attributes {"triton_gpu.num-warps" = 1 : i32} {
  tt.func @convert_dot(%A: tensor<16x16xf32, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) {
    // expected-error@+1 {{element types of operands A and B must have same bit width}}
    %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} :
      tensor<16x16xf32, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
    tt.return
  }
}
// -----

// Invalid case: operand A has no dot-operand encoding while operand B does.
// The verifier requires A and B to carry matching encodings.
#mma0 = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[1,1]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=1}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>

module attributes {"triton_gpu.num-warps" = 1 : i32} {
  tt.func @convert_dot(%A: tensor<16x16xf16>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) {
    // expected-error@+1 {{mismatching encoding between A and B operands}}
    %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} :
      tensor<16x16xf16> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
    tt.return
  }
}
// -----

// Invalid case: operand A uses kWidth=1 while operand B uses kWidth=2.
// The verifier requires matching kWidth on both dot-operand encodings.
#mma0 = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[1,1]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=1}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>

module attributes {"triton_gpu.num-warps" = 1 : i32} {
  tt.func @convert_dot(%A: tensor<16x16xf16, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) {
    // expected-error@+1 {{mismatching kWidth between A and B operands}}
    %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} :
      tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
    tt.return
  }
}