[BACKEND] Add a configurable parameter for the number of threads per warp (#1719)

Add a configurable parameter for the number of threads per warp to support
other GPUs, such as Intel GPUs.

It defaults to 32, so the code logic on CUDA/AMD GPUs is unchanged.

Note: the Intel GPU GenX ISA is explicit SIMD and can support a variable
number of thread lanes per HW execution unit.
This commit is contained in:
chengjunlu
2023-06-03 07:55:06 +08:00
committed by GitHub
parent 035381aa28
commit 45ba9af6ed
12 changed files with 81 additions and 40 deletions

View File

@@ -20,7 +20,11 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
let options = [
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps">
"number of warps">,
Option<"threadsPerWarp", "threads-per-warp",
"int32_t", /*default*/"32",
"number of threads per warp">,
];
}

View File

@@ -12,12 +12,14 @@ namespace triton {
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp";
// Create the pass with numWarps passed from cl::opt.
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
// Create the pass with numWarps set explicitly.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass(int numWarps);
createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32);
} // namespace triton
} // namespace mlir

View File

@@ -229,31 +229,32 @@ for
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$sizePerThread,
"ArrayRef<unsigned>":$order,
"unsigned":$numWarps), [{
"unsigned":$numWarps,
"unsigned":$threadsPerWarp), [{
int rank = sizePerThread.size();
unsigned remainingLanes = 32;
unsigned remainingThreads = numWarps*32;
unsigned remainingLanes = threadsPerWarp;
unsigned remainingThreads = numWarps*threadsPerWarp;
unsigned remainingWarps = numWarps;
unsigned prevLanes = 1;
unsigned prevWarps = 1;
SmallVector<unsigned, 4> threadsPerWarp(rank);
SmallVector<unsigned, 4> rankedThreadsPerWarp(rank);
SmallVector<unsigned, 4> warpsPerCTA(rank);
for (int _dim = 0; _dim < rank - 1; ++_dim) {
int i = order[_dim];
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps);
rankedThreadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / rankedThreadsPerWarp[i], 1, remainingWarps);
remainingWarps /= warpsPerCTA[i];
remainingLanes /= threadsPerWarp[i];
remainingLanes /= rankedThreadsPerWarp[i];
remainingThreads /= threadsPerCTA;
prevLanes *= threadsPerWarp[i];
prevLanes *= rankedThreadsPerWarp[i];
prevWarps *= warpsPerCTA[i];
}
// Expand the last dimension to fill the remaining lanes and warps
threadsPerWarp[order[rank-1]] = 32 / prevLanes;
rankedThreadsPerWarp[order[rank-1]] = threadsPerWarp / prevLanes;
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
return $_get(context, sizePerThread, rankedThreadsPerWarp, warpsPerCTA, order);
}]>
];

View File

@@ -29,6 +29,16 @@ def TritonGPU_Dialect : Dialect {
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return numWarps.cast<IntegerAttr>().getInt();
}
static std::string getThreadsPerWarpAttrName() { return "triton_gpu.threads-per-warp"; }
// Returns the number of threads per warp recorded on the module attribute,
// falling back to 32 (the CUDA/AMD warp size) when the attribute is absent.
static int getThreadsPerWarp(ModuleOp mod) {
  // Use the shared attribute-name helper instead of re-spelling the string,
  // so the attribute key stays in sync with getThreadsPerWarpAttrName().
  Attribute threadsPerWarp =
      mod->getDiscardableAttr(getThreadsPerWarpAttrName());
  if (!threadsPerWarp) {
    return 32;
  }
  return threadsPerWarp.cast<IntegerAttr>().getInt();
}
}];
let useDefaultAttributePrinterParser = 1;

View File

@@ -13,12 +13,15 @@ namespace mlir {
class TritonGPUTypeConverter : public TypeConverter {
public:
TritonGPUTypeConverter(MLIRContext *context, int numWarps);
TritonGPUTypeConverter(MLIRContext *context, int numWarps,
int threadsPerWarp);
int getNumWarps() const { return numWarps; }
int getThreadsPerWarp() const { return threadsPerWarp; }
private:
MLIRContext *context;
int numWarps;
int threadsPerWarp;
};
class TritonGPUConversionTarget : public ConversionTarget {

View File

@@ -72,7 +72,9 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
/// shared memory block1:
auto mod = op->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
smemShapes[1].push_back(numWarps * 32);
unsigned threadsPerWarp =
triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
smemShapes[1].push_back(numWarps * threadsPerWarp);
return smemShapes;
}

View File

@@ -303,9 +303,10 @@ public:
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMConversionTarget target(*context, isROCM);
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
// Preprocess
decomposeMmaToDotOperand(mod, numWarps);
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp);
decomposeBlockedToDotOperand(mod);
decomposeInsertSliceAsyncOp(mod);
@@ -432,7 +433,8 @@ private:
allocation.getSharedMemorySize()));
}
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps) const {
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps,
int threadsPerWarp) const {
// Replace `mma -> dot_op` with `mma -> blocked -> dot_op`
// unless certain conditions are met
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
@@ -448,7 +450,7 @@ private:
dstType.getShape(), dstType.getElementType(),
triton::gpu::BlockedEncodingAttr::get(
mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
getOrder(srcMma), numWarps));
getOrder(srcMma), numWarps, threadsPerWarp));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(

View File

@@ -254,15 +254,17 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
auto origShape = origType.getShape();
auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();
int numWarps = typeConverter->getNumWarps();
int threadsPerWarp = typeConverter->getThreadsPerWarp();
SmallVector<unsigned> retSizePerThread = {1, 1};
if (origShape[0] * origShape[1] / (numWarps * 32) >= 4)
if (origShape[0] * origShape[1] / (numWarps * threadsPerWarp) >= 4)
retSizePerThread = {2, 2};
if (origShape[0] * origShape[1] / (numWarps * 32) >= 16)
if (origShape[0] * origShape[1] / (numWarps * threadsPerWarp) >= 16)
retSizePerThread = {4, 4};
SmallVector<unsigned> retOrder = {1, 0};
Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get(
getContext(), origShape, retSizePerThread, retOrder, numWarps);
getContext(), origShape, retSizePerThread, retOrder, numWarps,
threadsPerWarp);
RankedTensorType retType =
RankedTensorType::get(origShape, origType.getElementType(), dEncoding);
// a & b must be of smem layout
@@ -806,13 +808,16 @@ class ConvertTritonToTritonGPU
public:
ConvertTritonToTritonGPU() = default;
// constructor with some parameters set explicitly.
ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; }
ConvertTritonToTritonGPU(int numWarps, int threadsPerWarp) {
this->numWarps = numWarps;
this->threadsPerWarp = threadsPerWarp;
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
// type converter
TritonGPUTypeConverter typeConverter(context, numWarps);
TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp);
TritonGPUConversionTarget target(*context, typeConverter);
// rewrite patterns
RewritePatternSet patterns(context);
@@ -835,6 +840,9 @@ public:
mod->setAttr(
AttrNumWarpsName,
IntegerAttr::get(i32_ty, llvm::APInt(32, numWarps.getValue())));
mod->setAttr(
AttrNumThreadsPerWarp,
IntegerAttr::get(i32_ty, llvm::APInt(32, threadsPerWarp.getValue())));
// update layouts
// broadcast src => multicast, dst => broadcasted
@@ -846,8 +854,9 @@ public:
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
return std::make_unique<::ConvertTritonToTritonGPU>(numWarps);
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps,
int threadsPerWarp) {
return std::make_unique<::ConvertTritonToTritonGPU>(numWarps, threadsPerWarp);
}
std::unique_ptr<OperationPass<ModuleOp>>

View File

@@ -23,7 +23,7 @@ typedef DenseMap<Value, std::function<Type(Type)>> LayoutMap;
struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
Attribute getCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis,
Value ptr, int numWarps) {
Value ptr, int numWarps, int threadsPerWarp) {
auto origType = ptr.getType().cast<RankedTensorType>();
// Get the shape of the tensor.
size_t rank = origType.getRank();
@@ -46,7 +46,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
}
}
int numElems = product(origType.getShape());
int numThreads = numWarps * 32;
int numThreads = numWarps * threadsPerWarp;
int numElemsPerThread = std::max(numElems / numThreads, 1);
// Thread tile size depends on memory alignment
SmallVector<unsigned, 4> sizePerThread(rank, 1);
@@ -68,14 +68,16 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
std::iota(dims.begin(), dims.end(), 0);
// create encoding
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
&getContext(), origType.getShape(), sizePerThread, order, numWarps);
&getContext(), origType.getShape(), sizePerThread, order, numWarps,
threadsPerWarp);
return encoding;
}
std::function<Type(Type)>
getTypeConverter(ModuleAxisInfoAnalysis &axisInfoAnalysis, Value ptr,
int numWarps) {
Attribute encoding = getCoalescedEncoding(axisInfoAnalysis, ptr, numWarps);
int numWarps, int threadsPerWarp) {
Attribute encoding =
getCoalescedEncoding(axisInfoAnalysis, ptr, numWarps, threadsPerWarp);
return [encoding](Type _type) {
RankedTensorType type = _type.cast<RankedTensorType>();
return RankedTensorType::get(type.getShape(), type.getElementType(),
@@ -148,7 +150,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
return;
auto mod = curr->getParentOfType<ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
auto convertType = getTypeConverter(axisInfoAnalysis, ptr, numWarps);
int threadsPerWarp =
triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
auto convertType =
getTypeConverter(axisInfoAnalysis, ptr, numWarps, threadsPerWarp);
layoutMap[ptr] = convertType;
});

View File

@@ -12,8 +12,8 @@ using namespace mlir::triton::gpu;
// TypeConverter
//
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
int numWarps)
: context(context), numWarps(numWarps) {
int numWarps, int threadsPerWarp)
: context(context), numWarps(numWarps), threadsPerWarp(threadsPerWarp) {
addConversion([](Type type) { return type; });
addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
// types with encoding are already in the right format
@@ -29,7 +29,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
std::iota(order.begin(), order.end(), 0);
llvm::SmallVector<unsigned> sizePerThread(rank, 1);
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
this->context, shape, sizePerThread, order, this->numWarps);
this->context, shape, sizePerThread, order, this->numWarps,
this->threadsPerWarp);
return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
});

View File

@@ -1532,11 +1532,13 @@ void init_triton_ir(py::module &&m) {
self.addPass(mlir::triton::createRewriteTensorPointerPass(
computeCapability));
})
.def("add_convert_triton_to_tritongpu_pass",
[](mlir::PassManager &self, int numWarps) {
self.addPass(
mlir::triton::createConvertTritonToTritonGPUPass(numWarps));
})
.def(
"add_convert_triton_to_tritongpu_pass",
[](mlir::PassManager &self, int numWarps, int threadsPerWarp) {
self.addPass(mlir::triton::createConvertTritonToTritonGPUPass(
numWarps, threadsPerWarp));
},
py::arg("numWarps") = 4, py::arg("threadsPerWarp") = 32)
.def("add_tritongpu_pipeline_pass",
[](mlir::PassManager &self, int numStages) {
self.addPass(mlir::createTritonGPUPipelinePass(numStages));

View File

@@ -1,7 +1,7 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
tt.func @ops() {
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {{.*}}
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
@@ -36,7 +36,7 @@ tt.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
// CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2], order = [0, 1]}>
// CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}>
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {{.*}}
%c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
%c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
%c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32>