[BACKEND] Dedup elementwise in LLVM IR based on constancy (#2512)

### Summary

When Triton GPU IR is lowered into LLVM IR, we can make use of the
constancy information about the result of the elementwise ops to
deduplicate otherwise redundant computation. That is the contribution of
this PR: the constancy is checked and, if possible, some of the values
in LLVM IR are reused multiple times instead of computing equal values
separately.

The change is beneficial for the PyTorch 2 / TorchInductor-generated
Triton code, as the leftmost sub-indices extracted from the flat index
by div / mod operations can be equal, given sufficiently large 2^n
factor in the rightmost rightmost dimension(s). This makes the
computation resulting in those sub-indices redundant. Consequently,
under the necessary constancy conditions, the redundant indexing
arithmetics can be deduplicated. We observe up to 29% decrease in the
latency of some of our jagged tensor kernels
This commit is contained in:
Adnan Akhundov
2023-10-25 17:25:29 +02:00
committed by GitHub
parent e70e11e834
commit 7d55968fee
2 changed files with 216 additions and 23 deletions

View File

@@ -519,8 +519,118 @@ public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit ElementwiseOpConversionBase(
TritonGPUToLLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
TritonGPUToLLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit),
axisAnalysisPass(axisAnalysisPass) {}
// Try to deduplicate the resultVals based on the
// constancy properties of the result discovered by
// the axis analysis pass. If possible, redundant
// computation is eliminated.
SmallVector<Value> maybeDeduplicate(SourceOp op,
SmallVector<Value> resultVals) const {
if (!isMemoryEffectFree(op))
// the op has side effects: can't dedup
return resultVals;
SmallVector<Value> results = op->getResults();
if (results.size() == 0 || results.size() > 1)
// there must be exactly 1 result
return resultVals;
Value result = results[0];
Type type = result.getType();
if (!type)
return resultVals;
RankedTensorType rtType = type.dyn_cast<RankedTensorType>();
if (!rtType)
// the result must be a tensor
return resultVals;
Attribute encoding = rtType.getEncoding();
if (!encoding)
// encoding not available
return resultVals;
if (!encoding.dyn_cast<triton::gpu::BlockedEncodingAttr>() &&
!encoding.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
// TODO: constraining the ecndoing type here is necessary
// for avoiding crashes in the triton::gpu::getElemsPerThread
// call below happening in the test_core::test_fp8_dot_acc
return resultVals;
}
SmallVector<unsigned> elemsPerThread =
triton::gpu::getElemsPerThread(rtType);
int rank = elemsPerThread.size();
if (product<unsigned>(elemsPerThread) != resultVals.size())
return resultVals;
AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(result);
if (!axisInfo)
// axis info (e.g., constancy) not available
return resultVals;
SmallVector<unsigned> sizePerThread =
triton::gpu::getSizePerThread(encoding);
if (rank != sizePerThread.size())
return resultVals;
SmallVector<int64_t> constancy = axisInfo->getConstancy();
if (rank != constancy.size())
return resultVals;
bool hasConstancy = false;
for (int i = 0; i < rank; ++i) {
if (constancy[i] > sizePerThread[i]) {
if (constancy[i] % sizePerThread[i] != 0)
// constancy is not evenly covered by sizePerThread
return resultVals;
// can't move the values across different
// "sizePerThread"-sized blocks
constancy[i] = sizePerThread[i];
}
if (elemsPerThread[i] < 1 || constancy[i] < 1)
return resultVals;
if (!(elemsPerThread[i] % constancy[i] == 0 ||
constancy[i] % elemsPerThread[i] == 0))
// either the constancy along each dimension must fit
// into the elemsPerThread or the other way around
return resultVals;
if (constancy[i] > 1)
hasConstancy = true;
}
if (!hasConstancy)
// nothing to deduplicate
return resultVals;
if (rank > 1) {
// reorder the shape and constancy vectors by the axis order:
// from the fastest-changing to the smallest-changing axis
SmallVector<unsigned> order = triton::gpu::getOrder(encoding);
if (rank != order.size())
return resultVals;
ArrayRef<unsigned> orderRef(order);
elemsPerThread = reorder(ArrayRef<unsigned>(elemsPerThread), orderRef);
constancy = reorder(ArrayRef<int64_t>(constancy), orderRef);
}
SmallVector<unsigned> strides(rank, 1);
for (int i = 1; i < rank; ++i) {
strides[i] = strides[i - 1] * elemsPerThread[i - 1];
}
SmallVector<Value> dedupResultVals;
dedupResultVals.reserve(resultVals.size());
for (int i = 0; i < resultVals.size(); ++i) {
// each coordinate of the orig_idx is "coarsened" using the
// constancy along this dimension: the resulting dedup_idx
// points to the reused value in the original resultsVal
int orig_idx = i;
int dedup_idx = 0;
for (int j = 0; j < rank; ++j) {
int coord_j = orig_idx % elemsPerThread[j];
dedup_idx += (coord_j / constancy[j] * constancy[j]) * strides[j];
orig_idx /= elemsPerThread[j];
}
dedupResultVals.push_back(resultVals[dedup_idx]);
}
return dedupResultVals;
}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
@@ -561,6 +671,7 @@ public:
auto argTy = op->getOperand(0).getType();
resultVals = reorderValues(resultVals, argTy, resultTy);
}
resultVals = maybeDeduplicate(op, resultVals);
resultVals =
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
Value view = this->getTypeConverter()->packLLElements(loc, resultVals,
@@ -570,6 +681,9 @@ public:
return success();
}
protected:
ModuleAxisInfoAnalysis &axisAnalysisPass;
private:
int computeCapability;
};
@@ -601,8 +715,9 @@ struct FpToFpOpConversion
triton::FpToFpOp, FpToFpOpConversion>::ElementwiseOpConversionBase;
explicit FpToFpOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisAnalysisPass,
int computeCapability, PatternBenefit benefit = 1)
: ElementwiseOpConversionBase(typeConverter, benefit),
: ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit),
computeCapability(computeCapability) {}
static Value convertBf16ToFp32(Location loc,
@@ -1313,12 +1428,14 @@ void populateElementwiseOpToLLVMPatterns(
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
int computeCapability, PatternBenefit benefit) {
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
typeConverter, axisInfoAnalysis, benefit);
POPULATE_TERNARY_OP(arith::SelectOp, LLVM::SelectOp)
#undef POPULATE_TERNARY_OP
#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
typeConverter, axisInfoAnalysis, benefit);
POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // *
@@ -1342,7 +1459,8 @@ void populateElementwiseOpToLLVMPatterns(
#undef POPULATE_BINARY_OP
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
typeConverter, axisInfoAnalysis, benefit);
POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp)
@@ -1358,29 +1476,32 @@ void populateElementwiseOpToLLVMPatterns(
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP
patterns.add<AbsIOpConversion>(typeConverter, benefit);
patterns.add<AbsFOpConversion>(typeConverter, benefit);
patterns.add<CmpIOpConversion>(typeConverter, benefit);
patterns.add<CmpFOpConversion>(typeConverter, benefit);
patterns.add<AbsIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<AbsFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<CmpIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<CmpFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FDivOpConversion>(typeConverter, benefit);
patterns.add<FSubOpConversion>(typeConverter, benefit);
patterns.add<FAddOpConversion>(typeConverter, benefit);
patterns.add<FMulOpConversion>(typeConverter, benefit);
patterns.add<FDivOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FSubOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FAddOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FMulOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<ExtFOpConversion>(typeConverter, benefit);
patterns.add<TruncFOpConversion>(typeConverter, benefit);
patterns.add<FPToSIOpConversion>(typeConverter, benefit);
patterns.add<SIToFPOpConversion>(typeConverter, benefit);
patterns.add<IndexCastOpLowering>(typeConverter, benefit);
patterns.add<ExtFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<TruncFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FPToSIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<SIToFPOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<IndexCastOpLowering>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FpToFpOpConversion>(typeConverter, computeCapability, benefit);
patterns.add<FpToFpOpConversion>(typeConverter, axisInfoAnalysis,
computeCapability, benefit);
patterns.add<ExternElementwiseOpConversion>(typeConverter, benefit);
patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter, benefit);
patterns.add<ExternElementwiseOpConversion>(typeConverter, axisInfoAnalysis,
benefit);
patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter,
axisInfoAnalysis, benefit);
// ExpOpConversionApprox will try using ex2.approx if the input type is
// FP32. For other input types, ExpOpConversionApprox will return failure and
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
// __nv_expf for higher-precision calculation
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
patterns.add<ExpOpConversionApprox>(typeConverter, axisInfoAnalysis, benefit);
}

View File

@@ -0,0 +1,72 @@
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" --llvm-optimize-for-nvvm-target | FileCheck %s
// CHECK-LABEL: dedup_by_constancy_full
// CHECK-COUNT-5: llvm.add
// CHECK-NOT: llvm.add
// CHECK: llvm.icmp "slt"
// CHECK-NOT: llvm.icmp "slt"
// CHECK: llvm.sdiv
// CHECK-NOT: llvm.sdiv
// CHECK: llvm.getelementptr %arg0[[[REGISTER:%[0-9]+]]]
// CHECK-COUNT-7: llvm.getelementptr %arg0[[[REGISTER]]]
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER]]]
#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @dedup_by_constancy_full(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<256> : tensor<1024xi32, #blocked>
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg2 : (i32) -> tensor<1024xi32, #blocked>
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
%7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked>
%8 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
%9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
%10 = tt.load %9, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked>
%11 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
%12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
tt.store %12, %10, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked>
tt.return
}
}
// -----
// CHECK-LABEL: dedup_by_constancy_partial
// CHECK-COUNT-8: llvm.add
// CHECK-NOT: llvm.add
// CHECK: llvm.icmp "slt"
// CHECK-NOT: llvm.icmp "slt"
// CHECK-COUNT-2: llvm.sdiv
// CHECK-NOT: llvm.sdiv
// CHECK: llvm.getelementptr %arg0[[[REGISTER1:%[0-9]+]]]
// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER1]]]
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER1]]]
// CHECK: llvm.getelementptr %arg0[[[REGISTER2:%[0-9]+]]]
// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER2]]]
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER2]]]
#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @dedup_by_constancy_partial(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<4> : tensor<1024xi32, #blocked>
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg2 : (i32) -> tensor<1024xi32, #blocked>
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
%7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked>
%8 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
%9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
%10 = tt.load %9, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked>
%11 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
%12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
tt.store %12, %10, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked>
tt.return
}
}