[TEST] made lut_bmm pipeline test more concise and specific (#1488)

This commit is contained in:
Philippe Tillet
2023-04-08 19:17:35 -07:00
committed by GitHub
parent f7ad8ae022
commit b86425a28e
3 changed files with 66 additions and 128 deletions

View File

@@ -54,6 +54,10 @@ public:
std::optional<int64_t> getConstantValue() const { return constantValue; }
template <class T>
static void
initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity,
DimVectorT *divisibility, DimVectorT *constancy);
/// Comparison
bool operator==(const AxisInfo &other) const {
return (contiguity == other.contiguity) &&

View File

@@ -42,27 +42,52 @@ static constexpr int log2Int(int64_t num) {
// AxisInfo
//===----------------------------------------------------------------------===//
/// Seed per-dimension axis-info vectors from a function argument's
/// `tt.contiguity` / `tt.divisibility` / `tt.constancy` attributes.
///
/// \param argNumber   index of the function argument to inspect
/// \param funcOp      the enclosing function op (templated so it works for
///                    both `func::FuncOp` and `LLVM::LLVMFuncOp`)
/// \param contiguity  out: per-dim contiguity, overwritten if the attr is set
/// \param divisibility out: per-dim divisibility, overwritten if the attr is set
/// \param constancy   out: per-dim constancy, overwritten if the attr is set
template <class T>
void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
                                            DimVectorT *contiguity,
                                            DimVectorT *divisibility,
                                            DimVectorT *constancy) {
  // list of attributes that we care about
  SmallVector<std::pair<DimVectorT *, std::string>> retVecs;
  retVecs.push_back({contiguity, "tt.contiguity"});
  retVecs.push_back({divisibility, "tt.divisibility"});
  retVecs.push_back({constancy, "tt.constancy"});
  // initialize attributes one by one
  for (auto [vec, attrName] : retVecs) {
    Attribute attr = funcOp.getArgAttr(argNumber, attrName);
    // A scalar integer attribute broadcasts one value across all dimensions.
    // NOTE(review): the broadcast length is taken from contiguity->size() for
    // all three vectors — assumes the caller pre-sized them identically
    // (all rank-sized); confirm at call sites.
    if (auto int_attr = attr.dyn_cast_or_null<IntegerAttr>())
      *vec = DimVectorT(contiguity->size(), int_attr.getValue().getZExtValue());
    // A dense elements attribute supplies one value per dimension.
    if (auto dense_attr = attr.dyn_cast_or_null<DenseElementsAttr>()) {
      auto vals = dense_attr.getValues<int>();
      *vec = DimVectorT(vals.begin(), vals.end());
    }
  }
}
AxisInfo AxisInfo::getPessimisticValueState(Value value) {
auto rank = 1;
if (TensorType ty = value.getType().dyn_cast<TensorType>())
rank = ty.getRank();
auto contiHint = 1;
auto divHint = 1;
auto constHint = 1;
DimVectorT knownContiguity(rank, 1);
DimVectorT knownDivisibility(rank, 1);
DimVectorT knownConstancy(rank, 1);
BlockArgument blockArg = value.dyn_cast<BlockArgument>();
if (blockArg && blockArg.getOwner()->isEntryBlock()) {
Operation *op = blockArg.getOwner()->getParentOp();
if (func::FuncOp fun = dyn_cast<func::FuncOp>(op)) {
Attribute attr =
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
if (attr)
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
} else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op)) {
Attribute attr =
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
if (attr)
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
} else {
if (auto fun = dyn_cast<func::FuncOp>(op))
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
&knownContiguity, &knownDivisibility,
&knownConstancy);
// llvm codegen check alignment to generate vector load/store
// would be nice if this wasn't the case
else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op))
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
&knownContiguity, &knownDivisibility,
&knownConstancy);
else {
// Derive the divisibility of the induction variable only when
// the step and the lower bound are both constants
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
@@ -79,16 +104,13 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
step.getValue().cast<IntegerAttr>().getValue().getZExtValue();
auto k = gcd(lowerBoundVal, stepVal);
if (k != 0)
divHint = k;
knownDivisibility = DimVectorT(rank, k);
}
}
}
}
}
} else if (Operation *op = value.getDefiningOp()) {
DimVectorT knownContiguity(rank, 1);
DimVectorT knownDivisibility(rank, 1);
DimVectorT knownConstancy(rank, 1);
if (Attribute attr = op->getAttr("tt.divisibility")) {
auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
knownDivisibility = DimVectorT(vals.begin(), vals.end());
@@ -101,12 +123,9 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
knownConstancy = DimVectorT(vals.begin(), vals.end());
}
return AxisInfo(knownContiguity, knownDivisibility, knownConstancy);
}
return AxisInfo(/*knownContiguity=*/DimVectorT(rank, contiHint),
/*knownDivisibility=*/DimVectorT(rank, divHint),
/*knownConstancy=*/DimVectorT(rank, constHint));
return AxisInfo(knownContiguity, knownDivisibility, knownConstancy);
}
// The gcd of both arguments for each dimension

View File

@@ -230,124 +230,39 @@ func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
// CHECK: triton_gpu.insert_slice_async
// CHECK: triton_gpu.insert_slice_async
// CHECK: triton_gpu.async_commit_group
// CHECK: %[[LUT_PTR:.*]] = tt.addptr
// CHECK: %arg27 = %[[LUT_PTR]]
// CHECK: %[[LUT_BUFFER_0:.*]] = tt.load %arg27, {{.*}}
// CHECK: %[[LUT_BUFFER_0:.*]] = tt.load %arg15, {{.*}}
// CHECK: %[[LUT_BUFFER_1:.*]] = arith.muli {{.*}}, %[[LUT_BUFFER_0]]
// CHECK: %[[LUT_BUFFER_2:.*]] = tt.splat %[[LUT_BUFFER_1]]
// CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[LUT_BUFFER_2]]
// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %arg26, {{.*}}
// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %arg14, {{.*}}
// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]]
// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_0]]
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}>
func.func @lut_bmm(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
// CHECK: triton_gpu.async_wait {num = 2 : i32}
func.func @lut_bmm(%77: i64 {tt.divisibility=16: i32},
%76: index,
%49: tensor<16x16x!tt.ptr<f16>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%75: !tt.ptr<i64>,
%78: tensor<16x16xi32, #AL> {tt.constancy=16: i32, tt.divisibility=16: i32},
%60: tensor<16x16x!tt.ptr<f16>, #BL> {tt.divisibility=16: i32, tt.contiguity=16 : i32}) -> tensor<16x16xf32, #C>{
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #C>
%c4_i32 = arith.constant 4 : i32
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c0_i64 = arith.constant 0 : i64
%c1_i32 = arith.constant 1 : i32
%0 = tt.get_program_id {axis = 2 : i32} : i32
%1 = tt.get_program_id {axis = 0 : i32} : i32
%2 = tt.get_program_id {axis = 1 : i32} : i32
%3 = tt.get_num_programs {axis = 0 : i32} : i32
%4 = tt.get_num_programs {axis = 1 : i32} : i32
%5 = arith.muli %1, %4 : i32
%6 = arith.addi %5, %2 : i32
%7 = arith.muli %4, %c4_i32 : i32
%8 = arith.divsi %6, %7 : i32
%9 = arith.muli %8, %c4_i32 : i32
%10 = arith.subi %3, %9 : i32
%11 = arith.cmpi slt, %10, %c4_i32 : i32
%12 = arith.select %11, %10, %c4_i32 : i32
%13 = arith.remsi %6, %12 : i32
%14 = arith.addi %9, %13 : i32
%15 = arith.remsi %6, %7 : i32
%16 = arith.divsi %15, %12 : i32
%17 = arith.muli %arg5, %0 : i32
%18 = tt.addptr %arg4, %17 : !tt.ptr<i64>, i32
%19 = tt.addptr %18, %14 : !tt.ptr<i64>, i32
%20 = tt.load %19 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : i64
%21 = tt.addptr %19, %c1_i32 : !tt.ptr<i64>, i32
%22 = tt.load %21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : i64
%23 = arith.subi %22, %20 : i64
%24 = arith.cmpi eq, %23, %c0_i64 : i64
cf.cond_br %24, ^bb1, ^bb2
^bb1: // pred: ^bb0
return
^bb2: // pred: ^bb0
%25 = arith.muli %arg1, %0 : i32
%26 = tt.addptr %arg0, %25 : !tt.ptr<f16>, i32
%27 = arith.extsi %arg2 : i32 to i64
%28 = arith.muli %27, %20 : i64
%29 = tt.addptr %26, %28 : !tt.ptr<f16>, i64
%30 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%31 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%32 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%33 = tt.expand_dims %30 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16x1xi32, #blocked>
%34 = tt.expand_dims %31 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xi32, #blocked1>
%35 = tt.expand_dims %32 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16x1xi32, #blocked>
%36 = tt.splat %arg3 : (i32) -> tensor<16x1xi32, #blocked>
%37 = arith.muli %36, %33 : tensor<16x1xi32, #blocked>
%38 = tt.splat %29 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked>
%39 = tt.addptr %38, %37 : tensor<16x1x!tt.ptr<f16>, #blocked>, tensor<16x1xi32, #blocked>
%40 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%41 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%42 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%43 = tt.expand_dims %40 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x16xi32, #blocked>
%44 = tt.expand_dims %41 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x16xi32, #blocked1>
%45 = tt.expand_dims %42 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x16xi32, #blocked>
%46 = tt.broadcast %39 : (tensor<16x1x!tt.ptr<f16>, #blocked>) -> tensor<16x16x!tt.ptr<f16>, #blocked>
%47 = tt.broadcast %43 : (tensor<1x16xi32, #blocked>) -> tensor<16x16xi32, #blocked>
%48 = tt.broadcast %45 : (tensor<1x16xi32, #blocked>) -> tensor<16x16xi32, #blocked>
%49 = tt.addptr %46, %47 : tensor<16x16x!tt.ptr<f16>, #blocked>, tensor<16x16xi32, #blocked>
%50 = arith.muli %arg9, %0 : i32
%51 = tt.addptr %arg8, %50 : !tt.ptr<f16>, i32
%52 = arith.muli %arg11, %16 : i32
%53 = tt.addptr %51, %52 : !tt.ptr<f16>, i32
%54 = tt.splat %53 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked1>
%55 = tt.addptr %54, %34 : tensor<16x1x!tt.ptr<f16>, #blocked1>, tensor<16x1xi32, #blocked1>
%56 = tt.splat %arg12 : (i32) -> tensor<1x16xi32, #blocked1>
%57 = arith.muli %56, %44 : tensor<1x16xi32, #blocked1>
%58 = tt.broadcast %55 : (tensor<16x1x!tt.ptr<f16>, #blocked1>) -> tensor<16x16x!tt.ptr<f16>, #blocked1>
%59 = tt.broadcast %57 : (tensor<1x16xi32, #blocked1>) -> tensor<16x16xi32, #blocked1>
%60 = tt.addptr %58, %59 : tensor<16x16x!tt.ptr<f16>, #blocked1>, tensor<16x16xi32, #blocked1>
%61 = arith.muli %arg14, %0 : i32
%62 = tt.addptr %arg13, %61 : !tt.ptr<f16>, i32
%63 = arith.muli %arg15, %14 : i32
%64 = tt.addptr %62, %63 : !tt.ptr<f16>, i32
%65 = arith.muli %arg16, %16 : i32
%66 = tt.addptr %64, %65 : !tt.ptr<f16>, i32
%67 = tt.splat %arg17 : (i32) -> tensor<16x1xi32, #blocked>
%68 = arith.muli %67, %35 : tensor<16x1xi32, #blocked>
%69 = tt.splat %66 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked>
%70 = tt.addptr %69, %68 : tensor<16x1x!tt.ptr<f16>, #blocked>, tensor<16x1xi32, #blocked>
%71 = tt.broadcast %70 : (tensor<16x1x!tt.ptr<f16>, #blocked>) -> tensor<16x16x!tt.ptr<f16>, #blocked>
%72 = tt.addptr %71, %48 : tensor<16x16x!tt.ptr<f16>, #blocked>, tensor<16x16xi32, #blocked>
%73 = arith.muli %arg7, %0 : i32
%74 = tt.addptr %arg6, %73 : !tt.ptr<i64>, i32
%75 = tt.addptr %74, %20 : !tt.ptr<i64>, i64
%76 = arith.index_cast %23 : i64 to index
%77 = arith.extsi %arg10 : i32 to i64
%78 = tt.splat %arg2 : (i32) -> tensor<16x16xi32, #blocked>
%79:3 = scf.for %arg18 = %c0 to %76 step %c1 iter_args(%arg19 = %cst, %arg20 = %49, %arg21 = %75) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr<f16>, #blocked>, !tt.ptr<i64>) {
%82 = tt.load %arg20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #blocked>
%79:3 = scf.for %arg18 = %c0 to %76 step %c1 iter_args(%arg19 = %cst, %arg20 = %49, %arg21 = %75) -> (tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr<f16>, #AL>, !tt.ptr<i64>) {
%82 = tt.load %arg20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL>
%83 = tt.load %arg21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : i64
%84 = arith.muli %77, %83 : i64
%85 = tt.splat %84 : (i64) -> tensor<16x16xi64, #blocked1>
%86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr<f16>, #blocked1>, tensor<16x16xi64, #blocked1>
%87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #blocked1>
%88 = triton_gpu.convert_layout %82 : (tensor<16x16xf16, #blocked>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
%89 = triton_gpu.convert_layout %87 : (tensor<16x16xf16, #blocked1>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
%90 = tt.dot %88, %89, %arg19 {allowTF32 = true} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<16x16xf32, #mma>
%91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr<f16>, #blocked>, tensor<16x16xi32, #blocked>
%85 = tt.splat %84 : (i64) -> tensor<16x16xi64, #BL>
%86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr<f16>, #BL>, tensor<16x16xi64, #BL>
%87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
%88 = triton_gpu.convert_layout %82 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A>
%89 = triton_gpu.convert_layout %87 : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #B>
%90 = tt.dot %88, %89, %arg19 {allowTF32 = true} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C>
%91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr<f16>, #AL>, tensor<16x16xi32, #AL>
%92 = tt.addptr %arg21, %c1_i32 : !tt.ptr<i64>, i32
scf.yield %90, %91, %92 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr<f16>, #blocked>, !tt.ptr<i64>
scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr<f16>, #AL>, !tt.ptr<i64>
}
%80 = arith.truncf %79#0 : tensor<16x16xf32, #mma> to tensor<16x16xf16, #mma>
%81 = triton_gpu.convert_layout %80 : (tensor<16x16xf16, #mma>) -> tensor<16x16xf16, #blocked>
tt.store %72, %81 {cache = 1 : i32, evict = 1 : i32} : tensor<16x16xf16, #blocked>
return
return %79#0 : tensor<16x16xf32, #C>
}