mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Add a configurable parameter for the number of threads per warp (#1719)
Add a configurable parameter for the number of threads per warp so that other GPUs, such as Intel GPUs, can be supported. The parameter defaults to 32, so the code logic on CUDA/AMD GPUs does not change. Note: the Intel GPU GenX ISA is explicit SIMD and can support a variable number of thread lanes per hardware execution unit.
This commit is contained in:
@@ -20,7 +20,11 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
|
||||
let options = [
|
||||
Option<"numWarps", "num-warps",
|
||||
"int32_t", /*default*/"4",
|
||||
"number of warps">
|
||||
"number of warps">,
|
||||
|
||||
Option<"threadsPerWarp", "threads-per-warp",
|
||||
"int32_t", /*default*/"32",
|
||||
"number of threads per warp">,
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
@@ -12,12 +12,14 @@ namespace triton {
|
||||
|
||||
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
|
||||
|
||||
constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp";
|
||||
|
||||
// Create the pass with numWarps passed from cl::opt.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
|
||||
|
||||
// Create the pass with numWarps set explicitly.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonToTritonGPUPass(int numWarps);
|
||||
createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
@@ -229,31 +229,32 @@ for
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$sizePerThread,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"unsigned":$numWarps), [{
|
||||
"unsigned":$numWarps,
|
||||
"unsigned":$threadsPerWarp), [{
|
||||
int rank = sizePerThread.size();
|
||||
unsigned remainingLanes = 32;
|
||||
unsigned remainingThreads = numWarps*32;
|
||||
unsigned remainingLanes = threadsPerWarp;
|
||||
unsigned remainingThreads = numWarps*threadsPerWarp;
|
||||
unsigned remainingWarps = numWarps;
|
||||
unsigned prevLanes = 1;
|
||||
unsigned prevWarps = 1;
|
||||
SmallVector<unsigned, 4> threadsPerWarp(rank);
|
||||
SmallVector<unsigned, 4> rankedThreadsPerWarp(rank);
|
||||
SmallVector<unsigned, 4> warpsPerCTA(rank);
|
||||
for (int _dim = 0; _dim < rank - 1; ++_dim) {
|
||||
int i = order[_dim];
|
||||
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
|
||||
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
|
||||
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps);
|
||||
rankedThreadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
|
||||
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / rankedThreadsPerWarp[i], 1, remainingWarps);
|
||||
remainingWarps /= warpsPerCTA[i];
|
||||
remainingLanes /= threadsPerWarp[i];
|
||||
remainingLanes /= rankedThreadsPerWarp[i];
|
||||
remainingThreads /= threadsPerCTA;
|
||||
prevLanes *= threadsPerWarp[i];
|
||||
prevLanes *= rankedThreadsPerWarp[i];
|
||||
prevWarps *= warpsPerCTA[i];
|
||||
}
|
||||
// Expand the last dimension to fill the remaining lanes and warps
|
||||
threadsPerWarp[order[rank-1]] = 32 / prevLanes;
|
||||
rankedThreadsPerWarp[order[rank-1]] = threadsPerWarp / prevLanes;
|
||||
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
|
||||
|
||||
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
|
||||
return $_get(context, sizePerThread, rankedThreadsPerWarp, warpsPerCTA, order);
|
||||
|
||||
}]>
|
||||
];
|
||||
|
||||
@@ -29,6 +29,16 @@ def TritonGPU_Dialect : Dialect {
|
||||
"TritonGPU module should contain a triton_gpu.num-warps attribute");
|
||||
return numWarps.cast<IntegerAttr>().getInt();
|
||||
}
|
||||
|
||||
static std::string getThreadsPerWarpAttrName() { return "triton_gpu.threads-per-warp"; }
|
||||
static int getThreadsPerWarp(ModuleOp mod) {
|
||||
Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp");
|
||||
if(!threadsPerWarp) {
|
||||
return 32;
|
||||
}
|
||||
return threadsPerWarp.cast<IntegerAttr>().getInt();
|
||||
}
|
||||
|
||||
}];
|
||||
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
|
||||
@@ -13,12 +13,15 @@ namespace mlir {
|
||||
|
||||
class TritonGPUTypeConverter : public TypeConverter {
|
||||
public:
|
||||
TritonGPUTypeConverter(MLIRContext *context, int numWarps);
|
||||
TritonGPUTypeConverter(MLIRContext *context, int numWarps,
|
||||
int threadsPerWarp);
|
||||
int getNumWarps() const { return numWarps; }
|
||||
int getThreadsPerWarp() const { return threadsPerWarp; }
|
||||
|
||||
private:
|
||||
MLIRContext *context;
|
||||
int numWarps;
|
||||
int threadsPerWarp;
|
||||
};
|
||||
|
||||
class TritonGPUConversionTarget : public ConversionTarget {
|
||||
|
||||
@@ -72,7 +72,9 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
|
||||
/// shared memory block1:
|
||||
auto mod = op->getParentOfType<ModuleOp>();
|
||||
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
smemShapes[1].push_back(numWarps * 32);
|
||||
unsigned threadsPerWarp =
|
||||
triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
|
||||
smemShapes[1].push_back(numWarps * threadsPerWarp);
|
||||
|
||||
return smemShapes;
|
||||
}
|
||||
|
||||
@@ -303,9 +303,10 @@ public:
|
||||
TritonGPUToLLVMTypeConverter typeConverter(context, option);
|
||||
TritonLLVMConversionTarget target(*context, isROCM);
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
|
||||
|
||||
// Preprocess
|
||||
decomposeMmaToDotOperand(mod, numWarps);
|
||||
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp);
|
||||
decomposeBlockedToDotOperand(mod);
|
||||
decomposeInsertSliceAsyncOp(mod);
|
||||
|
||||
@@ -432,7 +433,8 @@ private:
|
||||
allocation.getSharedMemorySize()));
|
||||
}
|
||||
|
||||
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps) const {
|
||||
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps,
|
||||
int threadsPerWarp) const {
|
||||
// Replace `mma -> dot_op` with `mma -> blocked -> dot_op`
|
||||
// unless certain conditions are met
|
||||
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
@@ -448,7 +450,7 @@ private:
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::BlockedEncodingAttr::get(
|
||||
mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
|
||||
getOrder(srcMma), numWarps));
|
||||
getOrder(srcMma), numWarps, threadsPerWarp));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
|
||||
@@ -254,15 +254,17 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
auto origShape = origType.getShape();
|
||||
auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();
|
||||
int numWarps = typeConverter->getNumWarps();
|
||||
int threadsPerWarp = typeConverter->getThreadsPerWarp();
|
||||
|
||||
SmallVector<unsigned> retSizePerThread = {1, 1};
|
||||
if (origShape[0] * origShape[1] / (numWarps * 32) >= 4)
|
||||
if (origShape[0] * origShape[1] / (numWarps * threadsPerWarp) >= 4)
|
||||
retSizePerThread = {2, 2};
|
||||
if (origShape[0] * origShape[1] / (numWarps * 32) >= 16)
|
||||
if (origShape[0] * origShape[1] / (numWarps * threadsPerWarp) >= 16)
|
||||
retSizePerThread = {4, 4};
|
||||
SmallVector<unsigned> retOrder = {1, 0};
|
||||
Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get(
|
||||
getContext(), origShape, retSizePerThread, retOrder, numWarps);
|
||||
getContext(), origShape, retSizePerThread, retOrder, numWarps,
|
||||
threadsPerWarp);
|
||||
RankedTensorType retType =
|
||||
RankedTensorType::get(origShape, origType.getElementType(), dEncoding);
|
||||
// a & b must be of smem layout
|
||||
@@ -806,13 +808,16 @@ class ConvertTritonToTritonGPU
|
||||
public:
|
||||
ConvertTritonToTritonGPU() = default;
|
||||
// constructor with some parameters set explicitly.
|
||||
ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; }
|
||||
ConvertTritonToTritonGPU(int numWarps, int threadsPerWarp) {
|
||||
this->numWarps = numWarps;
|
||||
this->threadsPerWarp = threadsPerWarp;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp mod = getOperation();
|
||||
// type converter
|
||||
TritonGPUTypeConverter typeConverter(context, numWarps);
|
||||
TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp);
|
||||
TritonGPUConversionTarget target(*context, typeConverter);
|
||||
// rewrite patterns
|
||||
RewritePatternSet patterns(context);
|
||||
@@ -835,6 +840,9 @@ public:
|
||||
mod->setAttr(
|
||||
AttrNumWarpsName,
|
||||
IntegerAttr::get(i32_ty, llvm::APInt(32, numWarps.getValue())));
|
||||
mod->setAttr(
|
||||
AttrNumThreadsPerWarp,
|
||||
IntegerAttr::get(i32_ty, llvm::APInt(32, threadsPerWarp.getValue())));
|
||||
|
||||
// update layouts
|
||||
// broadcast src => multicast, dst => broadcasted
|
||||
@@ -846,8 +854,9 @@ public:
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
|
||||
return std::make_unique<::ConvertTritonToTritonGPU>(numWarps);
|
||||
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps,
|
||||
int threadsPerWarp) {
|
||||
return std::make_unique<::ConvertTritonToTritonGPU>(numWarps, threadsPerWarp);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
|
||||
@@ -23,7 +23,7 @@ typedef DenseMap<Value, std::function<Type(Type)>> LayoutMap;
|
||||
|
||||
struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
Attribute getCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
Value ptr, int numWarps) {
|
||||
Value ptr, int numWarps, int threadsPerWarp) {
|
||||
auto origType = ptr.getType().cast<RankedTensorType>();
|
||||
// Get the shape of the tensor.
|
||||
size_t rank = origType.getRank();
|
||||
@@ -46,7 +46,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
}
|
||||
}
|
||||
int numElems = product(origType.getShape());
|
||||
int numThreads = numWarps * 32;
|
||||
int numThreads = numWarps * threadsPerWarp;
|
||||
int numElemsPerThread = std::max(numElems / numThreads, 1);
|
||||
// Thread tile size depends on memory alignment
|
||||
SmallVector<unsigned, 4> sizePerThread(rank, 1);
|
||||
@@ -68,14 +68,16 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
// create encoding
|
||||
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
|
||||
&getContext(), origType.getShape(), sizePerThread, order, numWarps);
|
||||
&getContext(), origType.getShape(), sizePerThread, order, numWarps,
|
||||
threadsPerWarp);
|
||||
return encoding;
|
||||
}
|
||||
|
||||
std::function<Type(Type)>
|
||||
getTypeConverter(ModuleAxisInfoAnalysis &axisInfoAnalysis, Value ptr,
|
||||
int numWarps) {
|
||||
Attribute encoding = getCoalescedEncoding(axisInfoAnalysis, ptr, numWarps);
|
||||
int numWarps, int threadsPerWarp) {
|
||||
Attribute encoding =
|
||||
getCoalescedEncoding(axisInfoAnalysis, ptr, numWarps, threadsPerWarp);
|
||||
return [encoding](Type _type) {
|
||||
RankedTensorType type = _type.cast<RankedTensorType>();
|
||||
return RankedTensorType::get(type.getShape(), type.getElementType(),
|
||||
@@ -148,7 +150,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
return;
|
||||
auto mod = curr->getParentOfType<ModuleOp>();
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
auto convertType = getTypeConverter(axisInfoAnalysis, ptr, numWarps);
|
||||
int threadsPerWarp =
|
||||
triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
|
||||
auto convertType =
|
||||
getTypeConverter(axisInfoAnalysis, ptr, numWarps, threadsPerWarp);
|
||||
layoutMap[ptr] = convertType;
|
||||
});
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ using namespace mlir::triton::gpu;
|
||||
// TypeConverter
|
||||
//
|
||||
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
int numWarps)
|
||||
: context(context), numWarps(numWarps) {
|
||||
int numWarps, int threadsPerWarp)
|
||||
: context(context), numWarps(numWarps), threadsPerWarp(threadsPerWarp) {
|
||||
addConversion([](Type type) { return type; });
|
||||
addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
|
||||
// types with encoding are already in the right format
|
||||
@@ -29,7 +29,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
llvm::SmallVector<unsigned> sizePerThread(rank, 1);
|
||||
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
|
||||
this->context, shape, sizePerThread, order, this->numWarps);
|
||||
this->context, shape, sizePerThread, order, this->numWarps,
|
||||
this->threadsPerWarp);
|
||||
return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
|
||||
});
|
||||
|
||||
|
||||
@@ -1532,11 +1532,13 @@ void init_triton_ir(py::module &&m) {
|
||||
self.addPass(mlir::triton::createRewriteTensorPointerPass(
|
||||
computeCapability));
|
||||
})
|
||||
.def("add_convert_triton_to_tritongpu_pass",
|
||||
[](mlir::PassManager &self, int numWarps) {
|
||||
self.addPass(
|
||||
mlir::triton::createConvertTritonToTritonGPUPass(numWarps));
|
||||
})
|
||||
.def(
|
||||
"add_convert_triton_to_tritongpu_pass",
|
||||
[](mlir::PassManager &self, int numWarps, int threadsPerWarp) {
|
||||
self.addPass(mlir::triton::createConvertTritonToTritonGPUPass(
|
||||
numWarps, threadsPerWarp));
|
||||
},
|
||||
py::arg("numWarps") = 4, py::arg("threadsPerWarp") = 32)
|
||||
.def("add_tritongpu_pipeline_pass",
|
||||
[](mlir::PassManager &self, int numStages) {
|
||||
self.addPass(mlir::createTritonGPUPipelinePass(numStages));
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
|
||||
|
||||
tt.func @ops() {
|
||||
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
|
||||
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {{.*}}
|
||||
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
|
||||
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
|
||||
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
|
||||
@@ -36,7 +36,7 @@ tt.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
|
||||
// CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2], order = [0, 1]}>
|
||||
// CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}>
|
||||
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
|
||||
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {{.*}}
|
||||
%c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
|
||||
%c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
|
||||
%c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32>
|
||||
|
||||
Reference in New Issue
Block a user