mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit '721897fcc4f942aa97d2e9ba3787a5e213758177' into ifu-231108
Conflicts: bin/triton-translate.cpp lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp python/triton/compiler/compiler.py python/triton/runtime/jit.py python/tutorials/06-fused-attention.py test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -885,7 +885,8 @@ public:
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
|
||||
: dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(solver) {
|
||||
: dataflow::SparseForwardDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(
|
||||
solver) {
|
||||
// UnrealizedConversionCast:
|
||||
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
|
||||
// in the process of a PartialConversion, where UnrealizedConversionCast
|
||||
|
||||
@@ -38,6 +38,17 @@ bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
|
||||
getParentOrder(getSrcLayout())[0];
|
||||
}
|
||||
|
||||
// Returns the dimension order of the source layout with the reduction axis
// moved to the front. The relative order of all other dimensions is kept.
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
  auto srcLayout = getSrcLayout();
  auto order = triton::gpu::getOrder(srcLayout);
  auto it = std::find(order.begin(), order.end(), axis);
  // Moving a missing element would be undefined behaviour (the original
  // erase(end())), so fail loudly if the axis is not part of the order.
  assert(it != order.end() && "reduction axis must appear in the layout order");
  // Rotating [begin, it] by one position puts the axis first and shifts the
  // preceding dimensions right; this equals erase(it) + insert(begin, axis).
  std::rotate(order.begin(), it, it + 1);
  return order;
}
|
||||
|
||||
// Thread offset is the thread index offset of two adjacent threads on the
|
||||
// reduction axis within the warp.
|
||||
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
|
||||
@@ -56,11 +67,11 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
|
||||
threadOffset = threadsPerWarp[sliceLayout.getDim()];
|
||||
} else {
|
||||
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
|
||||
if (threadsPerWarp.size() == 1) {
|
||||
threadOffset = 1;
|
||||
} else {
|
||||
assert(threadsPerWarp.size() == 2 && "Only supports 2D layouts");
|
||||
threadOffset = axis == 0 ? threadsPerWarp[1] : threadsPerWarp[0];
|
||||
auto order = triton::gpu::getOrder(srcLayout);
|
||||
for (unsigned i = 0; i < order.size(); i++) {
|
||||
if (order[i] == axis)
|
||||
break;
|
||||
threadOffset *= threadsPerWarp[order[i]];
|
||||
}
|
||||
}
|
||||
return threadOffset;
|
||||
@@ -150,8 +161,10 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
|
||||
}
|
||||
|
||||
bool ReduceOpHelper::isWarpSynchronous() {
|
||||
auto argsLayout = getSrcLayout();
|
||||
return triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1;
|
||||
auto srcLayout = getSrcLayout();
|
||||
auto srcShape = getSrcShape();
|
||||
return triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] ==
|
||||
1;
|
||||
}
|
||||
|
||||
SmallVector<unsigned> ReduceOpHelper::getScratchConfig() {
|
||||
@@ -502,10 +515,10 @@ static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
|
||||
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
|
||||
return src && dst && src.getVersionMajor() == 3 &&
|
||||
src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
|
||||
dst.getWarpsPerCTA()[1] == 1 && srcInstrShape[2] == dstInstrShape[2];
|
||||
dst.getWarpsPerCTA()[1] == 1;
|
||||
}
|
||||
|
||||
// Tensor-type convenience overload: forwards the encodings of both tensor
// types to the attribute-based isMmaToMmaShortcut check.
//
// The types are taken by value (the stale by-reference signature line is
// dropped), so both lvalues and temporaries are accepted.
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
  return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding());
}
|
||||
|
||||
@@ -521,7 +534,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
|
||||
srcTy.getElementType().isF16();
|
||||
}
|
||||
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
|
||||
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
|
||||
return true;
|
||||
// dot_op<opIdx=0, parent=#mma> = #mma
|
||||
@@ -713,7 +726,10 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
|
||||
auto *currentOp = (slice)[currentIndex];
|
||||
// Compute and insert the backwardSlice starting from currentOp.
|
||||
backwardSlice.clear();
|
||||
getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
|
||||
mlir::BackwardSliceOptions opt;
|
||||
opt.omitBlockArguments = true;
|
||||
opt.filter = backwardFilter;
|
||||
getBackwardSlice(currentOp, &backwardSlice, opt);
|
||||
slice.insert(backwardSlice.begin(), backwardSlice.end());
|
||||
|
||||
// Compute and insert the forwardSlice starting from currentOp.
|
||||
|
||||
@@ -14,7 +14,7 @@ add_mlir_conversion_library(NVGPUToLLVM
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRGPUOps
|
||||
MLIRGPUDialect
|
||||
MLIRGPUToNVVMTransforms
|
||||
MLIRGPUToROCDLTransforms
|
||||
MLIRGPUTransforms
|
||||
|
||||
@@ -29,8 +29,6 @@ const std::string Reg_Alloc_Op = "setmaxnreg.inc.sync.aligned.u32 #regCount;";
|
||||
const std::string Wgmma_Fence_Op = "wgmma.fence.sync.aligned;";
|
||||
const std::string Cga_Barrier_Sync_op = "barrier.cluster.sync.aligned;";
|
||||
const std::string Wgmma_Commit_Group_Op = "wgmma.commit_group.sync.aligned;";
|
||||
const std::string Wgmma_Wait_Group_Op =
|
||||
"wgmma.wait_group.sync.aligned #pendings;";
|
||||
const std::string Cluster_Wait_Op = "barrier.cluster.wait.aligned;";
|
||||
const std::string Fence_Mbarrier_Init_Op =
|
||||
"fence.mbarrier_init.release.cluster;";
|
||||
@@ -200,29 +198,6 @@ public:
|
||||
return {};
|
||||
}
|
||||
|
||||
// Builds the result type of the inline-asm call from its output constraints:
//   - no constraints       -> void
//   - exactly one          -> the type mapped from that constraint
//   - more than one        -> a struct of the mapped types, in order
// Every constraint must start with '='; the character after it is handed to
// getTypeFromConstraint.
Type getReturnType(std::vector<std::string> outputConstraints,
                   mlir::PatternRewriter &rewriter) const {
  auto ctx = rewriter.getContext();
  if (outputConstraints.empty())
    return void_ty(ctx);

  SmallVector<Type> retTys;
  for (auto &constraint : outputConstraints) {
    assert(constraint[0] == '=' && "Constraint must be for an output");
    retTys.push_back(getTypeFromConstraint(constraint[1], rewriter));
  }

  if (retTys.size() == 1)
    return retTys[0];
  Type packedTy = struct_ty(retTys);
  return packedTy;
}
|
||||
|
||||
std::string patchPtxAsm(mlir::Operation *op, std::string ptxAsm) const {
|
||||
std::vector<std::pair<int, int>> patchLocations;
|
||||
std::vector<std::string> patchValues;
|
||||
@@ -285,7 +260,8 @@ public:
|
||||
outputsAndOperands.append(ptxOperands.begin(), ptxOperands.end());
|
||||
auto &ptxInstr = *ptxBuilder.create<PTXInstr>(ptxAsmPatched);
|
||||
ptxInstr(outputsAndOperands, /*onlyAttachMLIRArgs=*/true);
|
||||
auto retTy = getReturnType(outputConstraints, rewriter);
|
||||
auto retTy =
|
||||
op->getNumResults() == 0 ? void_ty(ctx) : op->getResult(0).getType();
|
||||
auto res = ptxBuilder.launch(rewriter, loc, retTy,
|
||||
/*hasSideEffects*/ hasSideEffects);
|
||||
if (op->getNumResults() == 0) {
|
||||
@@ -700,6 +676,45 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class WGMMAWaitGroupOpPattern
|
||||
: public NVGPUOpPatternBase<ttn::WGMMAWaitGroupOp,
|
||||
WGMMAWaitGroupOpPattern> {
|
||||
public:
|
||||
using Base =
|
||||
NVGPUOpPatternBase<ttn::WGMMAWaitGroupOp, WGMMAWaitGroupOpPattern>;
|
||||
using Base::Base;
|
||||
|
||||
std::vector<std::string>
|
||||
getOutputConstraints(ttn::WGMMAWaitGroupOp op) const {
|
||||
auto outputStructType = op.getType().cast<LLVM::LLVMStructType>();
|
||||
uint32_t numOutputRegs = outputStructType.getBody().size();
|
||||
std::string output =
|
||||
outputStructType.getBody().front().isF32() ? "=f" : "=r";
|
||||
return std::vector<std::string>(numOutputRegs, output);
|
||||
}
|
||||
|
||||
OperandsAndConstraints
|
||||
getOperandsAndConstraints(ttn::WGMMAWaitGroupOp op) const {
|
||||
OperandsAndConstraints operandsAndConstraints;
|
||||
auto input = op.getInput();
|
||||
operandsAndConstraints.push_back({input, "0"});
|
||||
return operandsAndConstraints;
|
||||
}
|
||||
|
||||
std::string getPtxAsm(ttn::WGMMAWaitGroupOp op) const {
|
||||
auto outputStructType = op.getType().dyn_cast<LLVM::LLVMStructType>();
|
||||
uint32_t numCRegs = outputStructType.getBody().size();
|
||||
std::string args = "";
|
||||
uint32_t asmOpIdx = 0;
|
||||
for (uint32_t i = 0; i < numCRegs; ++i) {
|
||||
args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ",");
|
||||
}
|
||||
auto ptxAsm = "// wait for regs: " + args + "\n\t" +
|
||||
"wgmma.wait_group.sync.aligned #pendings;";
|
||||
return ptxAsm;
|
||||
}
|
||||
};
|
||||
|
||||
class WGMMAOpPattern : public NVGPUOpPatternBase<ttn::WGMMAOp, WGMMAOpPattern> {
|
||||
public:
|
||||
using Base = NVGPUOpPatternBase<ttn::WGMMAOp, WGMMAOpPattern>;
|
||||
@@ -1072,7 +1087,6 @@ public:
|
||||
POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, Wgmma_Fence_Op)
|
||||
POPULATE_NVGPU_OP(ttn::CGABarrierSyncOp, Cga_Barrier_Sync_op)
|
||||
POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, Wgmma_Commit_Group_Op)
|
||||
POPULATE_NVGPU_OP(ttn::WGMMAWaitGroupOp, Wgmma_Wait_Group_Op)
|
||||
POPULATE_NVGPU_OP(ttn::ClusterWaitOp, Cluster_Wait_Op)
|
||||
POPULATE_NVGPU_OP(ttn::FenceMBarrierInitOp, Fence_Mbarrier_Init_Op)
|
||||
POPULATE_NVGPU_OP(ttn::CGABarrierArriveOp, Cga_Barrier_Arrive_Op)
|
||||
@@ -1100,7 +1114,8 @@ public:
|
||||
OffsetOfStmatrixV4OpPattern, MBarrierArriveOpPattern,
|
||||
ClusterArriveOpPattern, TMALoadTiledOpPattern,
|
||||
TMAStoreTiledOpPattern, LoadDSmemOpPattern, WGMMAOpPattern,
|
||||
StoreDSmemOpPattern, OffsetOfSts64OpPattern>(context);
|
||||
WGMMAWaitGroupOpPattern, StoreDSmemOpPattern,
|
||||
OffsetOfSts64OpPattern>(context);
|
||||
|
||||
if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed())
|
||||
signalPassFailure();
|
||||
|
||||
@@ -49,7 +49,7 @@ add_mlir_conversion_library(TritonGPUToLLVM
|
||||
ASMBuilder
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRGPUOps
|
||||
MLIRGPUDialect
|
||||
MLIRGPUToNVVMTransforms
|
||||
MLIRGPUToROCDLTransforms
|
||||
MLIRGPUTransforms
|
||||
|
||||
@@ -146,10 +146,8 @@ struct DotWaitOpConversion
|
||||
matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto pendings = op.getPendings();
|
||||
rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(op.getLoc(), pendings);
|
||||
|
||||
// Safe to remove the op since it doesn't have any return value.
|
||||
rewriter.eraseOp(op);
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::WGMMAWaitGroupOp>(
|
||||
op, adaptor.getInput(), pendings);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -168,7 +168,9 @@ DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
int numRepM = ceil<unsigned>(shapePerCTA[0], instrShape[0] * wpt[0]);
|
||||
int numRepK = ceil<unsigned>(shapePerCTA[1], instrShape[2]);
|
||||
|
||||
Value warp = udiv(thread, i32_val(32));
|
||||
// The descriptor should be calculated based on the first warp of the
|
||||
// warpgroup.
|
||||
Value warp = and_(udiv(thread, i32_val(32)), i32_val(0xFFFFFFFC));
|
||||
Value warpM = urem(warp, i32_val(wpt[0]));
|
||||
Value warpId = urem(warpM, i32_val(shapePerCTA[0] / instrShape[0]));
|
||||
|
||||
@@ -199,7 +201,7 @@ DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
int numRepK = ceil<unsigned>(shapePerCTA[0], instrShape[2]);
|
||||
int numRepN = ceil<unsigned>(shapePerCTA[1], instrShape[1] * wpt[1]);
|
||||
|
||||
Value warp = udiv(thread, i32_val(32));
|
||||
Value warp = and_(udiv(thread, i32_val(32)), i32_val(0xFFFFFFFC));
|
||||
Value warpMN = udiv(warp, i32_val(wpt[0]));
|
||||
Value warpN = urem(warpMN, i32_val(wpt[1]));
|
||||
Value warpId = urem(warpN, i32_val(shapePerCTA[1] / instrShape[1]));
|
||||
@@ -293,6 +295,26 @@ static bool isZero(Value v) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Threads the accumulator values through a WGMMAWaitGroupOp.
//
// All values in `acc` are packed into a literal LLVM struct whose members all
// have the type of acc[0]. The struct is passed to
// triton::nvgpu::WGMMAWaitGroupOp together with `pendings`, and the members of
// the op's result are extracted again in the same order.
//
// `acc` must not be empty: the member type is taken from its first element.
static SmallVector<Value> emitWait(ConversionPatternRewriter &rewriter,
                                   Location loc, SmallVector<Value> acc,
                                   int pendings) {
  assert(!acc.empty() && "emitWait expects at least one accumulator value");
  const int numValues = static_cast<int>(acc.size());

  SmallVector<Type> types(acc.size(), acc[0].getType());
  auto structTy =
      LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);

  // Pack the accumulators into a single struct value.
  Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structTy);
  for (int i = 0; i < numValues; ++i) {
    llvmStruct = insert_val(structTy, llvmStruct, acc[i], i);
  }

  Value res = rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, llvmStruct,
                                                               pendings);

  // Unpack the values returned by the wait op.
  SmallVector<Value> results;
  results.reserve(acc.size());
  for (int i = 0; i < numValues; ++i) {
    results.push_back(extract_val(types[0], res, i));
  }
  return results;
}
|
||||
|
||||
LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
Operation *op, Value a, Value b, Value c, Value d,
|
||||
@@ -427,7 +449,7 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
rewriter.create<triton::nvgpu::WGMMACommitGroupOp>(loc);
|
||||
|
||||
if (sync)
|
||||
rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, 0);
|
||||
mmaResults = emitWait(rewriter, loc, mmaResults, 0);
|
||||
|
||||
SmallVector<Value> results =
|
||||
unpackAccumulator(rewriter, loc, mmaResults, dTensorTy);
|
||||
|
||||
@@ -6,7 +6,24 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
|
||||
/* ----- FP8E5M2 ------ */
|
||||
// This data-type is the standard FP8E5M2 format
|
||||
// Returns the PTX snippet converting packed Fp16 values to FP8E5M2.
//   hasNativeFP == true : a single cvt.rn.satfinite.e5m2x2.f16x2 instruction.
//   hasNativeFP == false: a software sequence working on two 32-bit inputs.
static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
  if (hasNativeFP)
    return "cvt.rn.satfinite.e5m2x2.f16x2 $0, $1; \n\t";

  return "{ \n"
         ".reg .b32 a<2>; \n"
         // a0/a1 = inputs with the lowest bit of each half stripped
         "and.b32 a0, $1, 0xfffefffe; \n"
         "and.b32 a1, $2, 0xfffefffe; \n"
         // add 0x0080 to each half (round to nearest)
         "add.u32 a0, a0, 0x00800080; \n"
         "add.u32 a1, a1, 0x00800080; \n"
         // output = a1a0
         "prmt.b32 $0, a0, a1, 0x7531; \n\t"
         "}";
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
|
||||
@@ -356,11 +373,115 @@ const std::string Bf16_to_Fp8E5M2 =
|
||||
"or.b32 $0, nosign, sign; \n" // restore sign
|
||||
"}";
|
||||
#endif
|
||||
=======
|
||||
// Returns the PTX snippet converting packed FP8E5M2 values to Fp16.
//   hasNativeFP == true : a single cvt.rn.f16x2.e5m2x2 instruction.
//   hasNativeFP == false: two prmt byte permutes producing outputs $0 and $1
//                         from input $2.
static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
  if (hasNativeFP)
    return "cvt.rn.f16x2.e5m2x2 $0, $1; \n\t";

  return "{ \n"
         "prmt.b32 $0, 0, $2, 0x5140; \n\t"
         "prmt.b32 $1, 0, $2, 0x7362; \n\t"
         "}";
}
|
||||
|
||||
// Returns the PTX snippet converting packed FP8E5M2 values to Bf16.
//   hasNativeFP == true : convert to f16x2 first, then cvt.bf16.f16 per half.
//   hasNativeFP == false: software bit manipulation on 32-bit registers.
static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
  if (hasNativeFP)
    return "{ \n"
           ".reg .b32 a; \n"
           ".reg .f16 a<2>; \n"
           ".reg .b16 b<2>; \n"
           "cvt.rn.f16x2.e5m2x2 a, $1; \n"
           "mov.b32 {a0, a1}, a; \n"
           "cvt.bf16.f16 b0, a0; \n"
           "cvt.bf16.f16 b1, a1; \n"
           "mov.b32 $0, {b0, b1}; \n"
           "}";

  // Worked example from the original comments: input = 0xf1f2f3f4.
  return "{ \n"
         ".reg .b32 a<2>, b<2>; \n"
         "prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
         "prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
         // b = a & 0x7fff7fff (strip sign)
         "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
         "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
         // b >>= 3 (shift into bf16 position)
         "shr.b32 b0, b0, 3; \n"
         "shr.b32 b1, b1, 3; \n"
         // exponent compensate = 112: b += 112<<7 | 112<<7<<16
         "add.u32 b0, b0, 0x38003800; \n"
         "add.u32 b1, b1, 0x38003800; \n"
         // out = b | (0x80008000 & a) (restore sign)
         "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
         "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
         "}";
}
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
|
||||
// Returns the PTX snippet converting packed Bf16 values to FP8E5M2.
//   hasNativeFP == true : widen each bf16 half to f32, then one
//                         cvt.rn.satfinite.e5m2x2.f32.
//   hasNativeFP == false: software path: split off the sign, clamp each
//                         16-bit half to [0x3800, 0x57e0], round, rebias,
//                         shift, and re-attach the sign.
static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) {
  if (hasNativeFP)
    return "{ \n"
           ".reg .b16 a<2>; \n"
           ".reg .f32 b<2>; \n"
           "mov.b32 {a0, a1}, $1; \n"
           "cvt.f32.bf16 b0, a0; \n"
           "cvt.f32.bf16 b1, a1; \n"
           "cvt.rn.satfinite.e5m2x2.f32 $0, b0, b1; \n"
           "}";

  // Relation used throughout: bf16 = fp8>>3 + 112<<7, so
  // fp8_min = 0b00000000 -> bf16_min = 0x3800 and
  // fp8_max = 0b11111111 -> bf16_max = 0x57e0.
  return "{ \n"
         ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
         ".reg .u32 fp8_min, fp8_max, rn_; \n"
         "mov.u32 fp8_min, 0x38003800; \n"
         "mov.u32 fp8_max, 0x57e057e0; \n"
         "mov.u32 rn_, 0x00100010; \n" // round to nearest
         // sign = in & 0x80008000 (store sign)
         "and.b32 sign0, $1, 0x80008000; \n"
         "and.b32 sign1, $2, 0x80008000; \n"
         "prmt.b32 sign, sign0, sign1, 0x7531; \n"
         // nosign = in & 0x7fff7fff (strip sign)
         "and.b32 nosign0, $1, 0x7fff7fff; \n"
         "and.b32 nosign1, $2, 0x7fff7fff; \n"

         // nosign = clamp(nosign, min, max), done per 16-bit half
         ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
         "and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
         "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
         "min.u32 nosign_0_0, nosign_0_0, 0x57e00000; \n"
         "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
         "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
         "min.u32 nosign_0_1, nosign_0_1, 0x57e0; \n"
         "or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
         "and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
         "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
         "min.u32 nosign_1_0, nosign_1_0, 0x57e00000; \n"
         "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
         "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
         "min.u32 nosign_1_1, nosign_1_1, 0x57e0; \n"
         "or.b32 nosign1, nosign_1_0, nosign_1_1; \n"

         // nosign += rn_ (round to nearest)
         "add.u32 nosign0, nosign0, rn_; \n"
         "add.u32 nosign1, nosign1, rn_; \n"
         // nosign -= 0x38003800 (compensate offset)
         "sub.u32 nosign0, nosign0, 0x38003800; \n"
         "sub.u32 nosign1, nosign1, 0x38003800; \n"
         // nosign <<= 3 (shift into fp8e5 position)
         "shl.b32 nosign0, nosign0, 3; \n"
         "shl.b32 nosign1, nosign1, 3; \n"
         // nosign0 = 0xf100f200, nosign1 = 0xf300f400 -> nosign = 0xf3f4f1f2
         "prmt.b32 nosign, nosign0, nosign1, 0x7531; \n"
         "or.b32 $0, nosign, sign; \n" // restore sign
         "}";
}
|
||||
/* ----- FP8E4M3B15 ------ */
|
||||
// This data-type is a variant of the standard FP8E4M3 format.
|
||||
// It was designed for fast software conversion to FP16 on
|
||||
// nvidia GPUs that do not support it natively.
|
||||
<<<<<<< HEAD
|
||||
// Specifically, this data-type:
|
||||
// - has infinities
|
||||
// - has multiple nans (when all exponent bits are 1)
|
||||
@@ -404,6 +525,11 @@ Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
};
|
||||
}
|
||||
#else
|
||||
=======
|
||||
// This is the same format as FP8E4M3Nv, but:
|
||||
// - the exponent bias is 15 instead of 7
|
||||
// - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
const std::string Fp8E4M3B15_to_Fp16 =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n"
|
||||
@@ -416,6 +542,7 @@ const std::string Fp8E4M3B15_to_Fp16 =
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
|
||||
"shl.b32 $1, b1, 7; \n"
|
||||
"} \n";
|
||||
<<<<<<< HEAD
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
@@ -464,6 +591,10 @@ Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter,
|
||||
}
|
||||
#else
|
||||
const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
|
||||
=======
|
||||
|
||||
static const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
std::string ret;
|
||||
ret += "{ \n"
|
||||
".reg .pred p<4>; \n"
|
||||
@@ -509,6 +640,7 @@ const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
|
||||
// $0 = (($2 << 1) & 0x80008000u) | (($2 << 7) & 0x3f803f80u);
|
||||
// $1 = (($2 << 0) & 0x80008000u) | (($2 << 0) & 0x3f803f80u);
|
||||
// WARN: subnormal (0bs0000xxx) are not handled
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
@@ -540,6 +672,9 @@ Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
}
|
||||
#else
|
||||
const std::string Fp8E4M3B15x4_to_Fp16 =
|
||||
=======
|
||||
static const std::string Fp8E4M3B15x4_to_Fp16 =
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
"{ \n"
|
||||
".reg .b32 a<2>; \n"
|
||||
"add.u32 a0, $2, $2; \n"
|
||||
@@ -557,6 +692,7 @@ const std::string Fp8E4M3B15x4_to_Fp16 =
|
||||
// ((e4.y >> 0) & (0x80008000u >> 0)) |
|
||||
// ((e4.y >> 0) & (0x3f803f80u >> 0)) ;
|
||||
// WARN: subnormal (0bs0000xxx) are not handled
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
@@ -591,6 +727,9 @@ Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
}
|
||||
#else
|
||||
const std::string Fp16_to_Fp8E4M3B15x4 =
|
||||
=======
|
||||
static const std::string Fp16_to_Fp8E4M3B15x4 =
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
"{ \n"
|
||||
".reg .b32 a<2>; \n"
|
||||
"shr.b32 a0, $1, 1; \n"
|
||||
@@ -904,17 +1043,18 @@ const std::string Bf16_to_Fp8E4M3 =
|
||||
#endif
|
||||
|
||||
// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
|
||||
// PTX snippet: one cvt.rn.f16x2.e4m3x2 instruction ($1 -> $0).
// Defined once, with internal linkage; the stale duplicate definition without
// `static` is dropped because two definitions of the same name do not compile.
static const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
                                             "cvt.rn.f16x2.e4m3x2 $0, $1; \n"
                                             "}";
|
||||
// Fp16 (x2) -> Fp8E4M3 (x2) (packed)
|
||||
// PTX snippet: one cvt.rn.satfinite.e4m3x2.f16x2 instruction ($1 -> $0).
// Defined once, with internal linkage; the stale duplicate definition without
// `static` is dropped because two definitions of the same name do not compile.
static const std::string Fp16_to_Fp8E4M3Nv =
    "{ \n"
    "cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n"
    "}";
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
|
||||
const std::string Fp8E4M3Nv_to_Bf16 =
|
||||
static const std::string Fp8E4M3Nv_to_Bf16 =
|
||||
"{ \n"
|
||||
".reg .b32 a; \n"
|
||||
".reg .f16 a<2>; \n"
|
||||
@@ -927,7 +1067,7 @@ const std::string Fp8E4M3Nv_to_Bf16 =
|
||||
"}";
|
||||
|
||||
// Bf16 (x2) -> Fp8E4M3 (x2) (packed)
|
||||
const std::string Bf16_to_Fp8E4M3Nv =
|
||||
static const std::string Bf16_to_Fp8E4M3Nv =
|
||||
"{ \n"
|
||||
".reg .b16 a<2>; \n"
|
||||
".reg .f32 b<2>; \n"
|
||||
@@ -938,7 +1078,7 @@ const std::string Bf16_to_Fp8E4M3Nv =
|
||||
"}";
|
||||
|
||||
/* ----- Packed integer to BF16 ------ */
|
||||
const std::string S8_to_Bf16 =
|
||||
static const std::string S8_to_Bf16 =
|
||||
"{ \n"
|
||||
".reg .s8 s<4>; \n"
|
||||
".reg .f32 f<4>; \n"
|
||||
@@ -952,6 +1092,12 @@ const std::string S8_to_Bf16 =
|
||||
"}";
|
||||
#endif
|
||||
|
||||
// Fp32 (x2) -> Fp8 (x2) (packed)
|
||||
// PTX snippets converting two f32 inputs into one packed pair of FP8 values
// with cvt.rn.satfinite.<fmt>x2.f32.
// NOTE(review): the sources are passed as "$2, $1" (second input first) —
// presumably to match the byte order in which the instruction packs its two
// sources; confirm against the PTX ISA description of cvt.
static const std::string Fp32_to_Fp8E4M3Nv{
    "cvt.rn.satfinite.e4m3x2.f32 $0, $2, $1; \n"};
static const std::string Fp32_to_Fp8E5M2{
    "cvt.rn.satfinite.e5m2x2.f32 $0, $2, $1; \n"};
|
||||
|
||||
static SmallVector<Value> reorderValues(const SmallVector<Value> &values,
|
||||
Type inType, Type ouType) {
|
||||
auto inTensorTy = inType.dyn_cast<RankedTensorType>();
|
||||
@@ -1383,9 +1529,14 @@ struct FpToFpOpConversion
|
||||
// F8 -> F16
|
||||
{{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
|
||||
{{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
|
||||
<<<<<<< HEAD
|
||||
{{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16},
|
||||
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
|
||||
{{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16},
|
||||
=======
|
||||
{{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16},
|
||||
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16(computeCapability >= 90)},
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
// F16 -> F8
|
||||
#ifdef USE_ROCM
|
||||
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
|
||||
@@ -1393,27 +1544,44 @@ struct FpToFpOpConversion
|
||||
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
|
||||
#endif
|
||||
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
|
||||
<<<<<<< HEAD
|
||||
{{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ},
|
||||
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
|
||||
{{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ},
|
||||
// F8 -> BF16
|
||||
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
|
||||
#ifndef USE_ROCM
|
||||
=======
|
||||
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv},
|
||||
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2(computeCapability >= 90)},
|
||||
// F8 -> BF16
|
||||
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
{{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16},
|
||||
#endif
|
||||
// BF16 -> F8
|
||||
<<<<<<< HEAD
|
||||
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
|
||||
#ifndef USE_ROCM
|
||||
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
|
||||
#endif
|
||||
=======
|
||||
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2(computeCapability >= 90)},
|
||||
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
|
||||
// F32 -> F8
|
||||
{{F32TyID, F8E4M3TyID}, Fp32_to_Fp8E4M3Nv},
|
||||
{{F32TyID, F8E5M2TyID}, Fp32_to_Fp8E5M2},
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
};
|
||||
int inVecWidthBits = 32;
|
||||
int outVecWidthBits = 32;
|
||||
if (srcTy.isFloat8E4M3FNUZ()) {
|
||||
if (srcTy.isFloat8E4M3FNUZ() ||
|
||||
(computeCapability >= 90 && srcTy.isFloat8E5M2())) {
|
||||
inVecWidthBits = 16;
|
||||
outVecWidthBits = 32;
|
||||
}
|
||||
if (dstTy.isFloat8E4M3FNUZ()) {
|
||||
if (dstTy.isFloat8E4M3FNUZ() ||
|
||||
(computeCapability >= 90 && dstTy.isFloat8E5M2())) {
|
||||
inVecWidthBits = 32;
|
||||
outVecWidthBits = 16;
|
||||
}
|
||||
@@ -1450,18 +1618,24 @@ struct FpToFpOpConversion
|
||||
|
||||
size_t numElements = 4;
|
||||
if (srcElementType.isFloat8E4M3FNUZ() ||
|
||||
dstElementType.isFloat8E4M3FNUZ()) {
|
||||
dstElementType.isFloat8E4M3FNUZ() ||
|
||||
(computeCapability >= 90 &&
|
||||
(srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2()))) {
|
||||
numElements = 2;
|
||||
}
|
||||
bool isSrcFP32 = srcElementType.isF32();
|
||||
bool useFP16IntermediateSrc =
|
||||
srcElementType.isF32() &&
|
||||
!(computeCapability >= 90 &&
|
||||
(dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2()));
|
||||
bool isDstFP32 = dstElementType.isF32();
|
||||
auto cvtFunc = getConversionFunc(isSrcFP32 ? f16_ty : srcElementType,
|
||||
isDstFP32 ? f16_ty : dstElementType);
|
||||
auto cvtFunc =
|
||||
getConversionFunc(useFP16IntermediateSrc ? f16_ty : srcElementType,
|
||||
isDstFP32 ? f16_ty : dstElementType);
|
||||
SmallVector<Value> inVals;
|
||||
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
|
||||
inVals.push_back(operands[i][0]);
|
||||
}
|
||||
if (isSrcFP32)
|
||||
if (useFP16IntermediateSrc)
|
||||
for (Value &v : inVals)
|
||||
v = convertFp32ToFp16(loc, rewriter, v);
|
||||
inVals.resize(numElements,
|
||||
@@ -2115,18 +2289,18 @@ void populateElementwiseOpToLLVMPatterns(
|
||||
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
|
||||
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
|
||||
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
|
||||
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
|
||||
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
|
||||
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
|
||||
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
|
||||
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
|
||||
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
|
||||
POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin
|
||||
POPULATE_BINARY_OP(arith::MaxFOp, LLVM::MaxNumOp) // fmax
|
||||
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
|
||||
POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax
|
||||
POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin
|
||||
POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax
|
||||
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
|
||||
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
|
||||
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
|
||||
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
|
||||
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
|
||||
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
|
||||
POPULATE_BINARY_OP(arith::MinimumFOp, LLVM::MinNumOp) // fmin
|
||||
POPULATE_BINARY_OP(arith::MaximumFOp, LLVM::MaxNumOp) // fmax
|
||||
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
|
||||
POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax
|
||||
POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin
|
||||
POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax
|
||||
#undef POPULATE_BINARY_OP
|
||||
|
||||
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
|
||||
|
||||
@@ -1549,10 +1549,12 @@ struct InsertSliceAsyncOpConversion
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto resTy = dst.getType().cast<RankedTensorType>();
|
||||
auto resElemTy = getTypeConverter()->convertType(resTy.getElementType());
|
||||
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
assert((srcLayout.isa<BlockedEncodingAttr, SliceEncodingAttr>() &&
|
||||
"Unexpected srcLayout in InsertSliceAsyncOpConversion"));
|
||||
auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto srcShape = srcTy.getShape();
|
||||
assert(srcShape.size() == 2 &&
|
||||
assert((srcShape.size() == 1 || srcShape.size() == 2) &&
|
||||
"insert_slice_async: Unexpected rank of %src");
|
||||
|
||||
Value llDst = adaptor.getDst();
|
||||
@@ -1617,25 +1619,15 @@ struct InsertSliceAsyncOpConversion
|
||||
unsigned numElems = getTotalElemsPerThread(srcTy);
|
||||
unsigned perPhase = resSharedLayout.getPerPhase();
|
||||
unsigned maxPhase = resSharedLayout.getMaxPhase();
|
||||
auto sizePerThread = srcBlockedLayout.getSizePerThread();
|
||||
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
|
||||
auto inOrder = srcBlockedLayout.getOrder();
|
||||
DenseMap<unsigned, Value> sharedPtrs =
|
||||
getSwizzledSharedPtrs(loc, inVec, srcTy, resSharedLayout, resElemTy,
|
||||
smemObj, rewriter, offsetVals, srcStrides);
|
||||
|
||||
// If perPhase * maxPhase > threadsPerCTA, we will have elements
|
||||
// that share the same tile indices. The index calculation will
|
||||
// be cached.
|
||||
auto numSwizzleRows = std::max<unsigned>(
|
||||
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
|
||||
// A sharedLayout encoding has a "vec" parameter.
|
||||
// On the column dimension, if inVec > outVec, it means we have to divide
|
||||
// single vector read into multiple ones
|
||||
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
|
||||
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcTy);
|
||||
|
||||
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
|
||||
// 16 * 8 = 128bits
|
||||
auto maxBitWidth =
|
||||
|
||||
@@ -419,16 +419,15 @@ private:
|
||||
getMultiDimWarpId(helper, warpId, loc, rewriter);
|
||||
Value warpIdAxis = multiDimWarpId[axis];
|
||||
|
||||
if (!helper.isReductionOnLayoutFastAxis()) {
|
||||
std::reverse(order.begin(), order.end());
|
||||
}
|
||||
auto smemOrder = helper.getOrderWithAxisAtBeginning();
|
||||
for (auto it : accs) {
|
||||
const SmallVector<unsigned> &key = it.first;
|
||||
SmallVector<Value> acc = it.second;
|
||||
|
||||
SmallVector<Value> writeIdx = indices[key];
|
||||
writeIdx[axis] = warpIdAxis;
|
||||
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order);
|
||||
Value writeOffset =
|
||||
linearize(rewriter, loc, writeIdx, smemShape, smemOrder);
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
auto elemPtrTy = getElementPtrType(op, i);
|
||||
Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset);
|
||||
@@ -513,10 +512,7 @@ private:
|
||||
Location loc = op.getLoc();
|
||||
auto srcLayout = helper.getSrcLayout();
|
||||
auto axis = op.getAxis();
|
||||
auto order = getOrder(srcLayout);
|
||||
if (!helper.isReductionOnLayoutFastAxis()) {
|
||||
std::reverse(order.begin(), order.end());
|
||||
}
|
||||
auto smemOrder = helper.getOrderWithAxisAtBeginning();
|
||||
SmallVector<Value> results(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
if (auto resultTy =
|
||||
@@ -532,7 +528,7 @@ private:
|
||||
SmallVector<Value> readIdx = resultIndices[j];
|
||||
readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0));
|
||||
Value readOffset =
|
||||
linearize(rewriter, loc, readIdx, smemShape, order);
|
||||
linearize(rewriter, loc, readIdx, smemShape, smemOrder);
|
||||
Value readPtr =
|
||||
gep(getElementPtrType(op, i), smemBases[i], readOffset);
|
||||
resultVals[j] = load(readPtr);
|
||||
|
||||
@@ -622,10 +622,13 @@ struct AllocTensorOpConversion
|
||||
// TODO: we need to modify the pipeline pass to give a proper shared
|
||||
// encoding to 3D tensors
|
||||
SmallVector<unsigned> newOrder;
|
||||
if (resultTy.getShape().size() == 3)
|
||||
newOrder = {1 + order[0], 1 + order[1], 0};
|
||||
else
|
||||
if (resultTy.getShape().size() != order.size()) {
|
||||
for (auto i = 0; i < order.size(); ++i)
|
||||
newOrder.push_back(order[i] + 1);
|
||||
newOrder.push_back(0);
|
||||
} else {
|
||||
newOrder = SmallVector<unsigned>(order.begin(), order.end());
|
||||
}
|
||||
|
||||
auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
|
||||
auto smemObj =
|
||||
@@ -659,10 +662,13 @@ struct ExtractSliceOpConversion
|
||||
SmallVector<Value, 4> opOffsetVals;
|
||||
SmallVector<Value, 4> offsetVals;
|
||||
auto mixedOffsets = op.getMixedOffsets();
|
||||
for (auto i = 0; i < mixedOffsets.size(); ++i) {
|
||||
if (op.isDynamicOffset(i))
|
||||
opOffsetVals.emplace_back(adaptor.getOffsets()[i]);
|
||||
else
|
||||
for (auto i = 0, j = 0; i < mixedOffsets.size(); ++i) {
|
||||
if (op.isDynamicOffset(i)) {
|
||||
// adaptor.getOffsets() returns list of variable offsets. the size of
|
||||
// the list may not be the same as mixedOffsets
|
||||
opOffsetVals.emplace_back(adaptor.getOffsets()[j]);
|
||||
++j;
|
||||
} else
|
||||
opOffsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
|
||||
offsetVals.emplace_back(add(smemObj.offsets[i], opOffsetVals[i]));
|
||||
}
|
||||
|
||||
@@ -146,7 +146,8 @@ protected:
|
||||
}
|
||||
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
||||
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
|
||||
/*dsoLocal*/ false, LLVM::CConv::C, attributes);
|
||||
/*dsoLocal*/ false, LLVM::CConv::C, /*comdat=*/SymbolRefAttr{},
|
||||
attributes);
|
||||
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
||||
newFuncOp.end());
|
||||
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
|
||||
@@ -361,8 +362,13 @@ public:
|
||||
unsigned numElemsPerSwizzlingRow =
|
||||
swizzlingByteWidth * 8 / resElemTy.getIntOrFloatBitWidth();
|
||||
Value numElemsPerSwizzlingRowVal = i32_val(numElemsPerSwizzlingRow);
|
||||
unsigned leadingDimOffset =
|
||||
numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]];
|
||||
unsigned leadingDimOffset;
|
||||
if (outOrder.size() == 2) {
|
||||
leadingDimOffset = numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]];
|
||||
} else {
|
||||
leadingDimOffset = numElemsPerSwizzlingRow;
|
||||
}
|
||||
|
||||
Value leadingDimOffsetVal = i32_val(leadingDimOffset);
|
||||
// Return values
|
||||
DenseMap<unsigned, Value> ret;
|
||||
@@ -374,9 +380,15 @@ public:
|
||||
// Extract multi dimensional index for current element
|
||||
auto idx = srcIndices[elemIdx];
|
||||
Value idxCol = idx[outOrder[0]]; // contiguous dimension
|
||||
Value idxRow = idx[outOrder[1]]; // discontiguous dimension
|
||||
Value idxRow, strideRow;
|
||||
if (outOrder.size() == 2) {
|
||||
idxRow = idx[outOrder[1]]; // discontiguous dimension
|
||||
strideRow = srcStrides[outOrder[1]];
|
||||
} else {
|
||||
idxRow = i32_val(0);
|
||||
strideRow = i32_val(0);
|
||||
}
|
||||
Value strideCol = srcStrides[outOrder[0]];
|
||||
Value strideRow = srcStrides[outOrder[1]];
|
||||
// compute phase = (row // perPhase) % maxPhase
|
||||
Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase));
|
||||
// extract dynamic/static offset for immediate offsetting
|
||||
@@ -428,10 +440,16 @@ public:
|
||||
offset = add(offset, add(rowOff, mul(colOff, strideCol)));
|
||||
Value currPtr = gep(dstPtrTy, dstPtrBase, offset);
|
||||
// compute immediate offset
|
||||
Value immedateOff =
|
||||
add(mul(i32_val(immedateOffRow), srcStrides[outOrder[1]]),
|
||||
i32_val(immedateOffCol));
|
||||
ret[elemIdx] = gep(dstPtrTy, currPtr, immedateOff);
|
||||
Value immediateOff;
|
||||
if (outOrder.size() == 2) {
|
||||
immediateOff =
|
||||
add(mul(i32_val(immedateOffRow), srcStrides[outOrder[1]]),
|
||||
i32_val(immedateOffCol));
|
||||
} else {
|
||||
immediateOff = i32_val(immedateOffCol);
|
||||
}
|
||||
|
||||
ret[elemIdx] = gep(dstPtrTy, currPtr, immediateOff);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -371,13 +371,15 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
|
||||
Type type = val.getType();
|
||||
if (type != i32_ty) {
|
||||
val = bitcast(val, int_ty(bits));
|
||||
val = zext(i32_ty, val);
|
||||
if (bits < 32)
|
||||
val = zext(i32_ty, val);
|
||||
}
|
||||
Value mask = i32_val(0xFFFFFFFF);
|
||||
Value result = rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, val, i, clamp,
|
||||
mode, UnitAttr());
|
||||
if (type != i32_ty) {
|
||||
result = trunc(int_ty(bits), result);
|
||||
if (bits < 32)
|
||||
result = trunc(int_ty(bits), result);
|
||||
result = bitcast(result, type);
|
||||
}
|
||||
return result;
|
||||
|
||||
@@ -97,7 +97,7 @@
|
||||
ptxBuilder.launch(rewriter, op->getLoc(), voidTy); \
|
||||
} while (0)
|
||||
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
|
||||
#define null(...) rewriter.create<LLVM::NullOp>(loc, __VA_ARGS__)
|
||||
#define null(...) rewriter.create<LLVM::ZeroOp>(loc, __VA_ARGS__)
|
||||
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)
|
||||
|
||||
// Types
|
||||
|
||||
@@ -115,8 +115,8 @@ void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
// Floating point
|
||||
GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
|
||||
// MaxMin
|
||||
GenericOpPattern<arith::MaxFOp>, GenericOpPattern<arith::MaxSIOp>,
|
||||
GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinFOp>,
|
||||
GenericOpPattern<arith::MaximumFOp>, GenericOpPattern<arith::MaxSIOp>,
|
||||
GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinimumFOp>,
|
||||
GenericOpPattern<arith::MinSIOp>, GenericOpPattern<arith::MinUIOp>,
|
||||
// Floating point
|
||||
GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/FunctionImplementation.h"
|
||||
#include "mlir/IR/FunctionInterfaces.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/Interfaces/FunctionImplementation.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
|
||||
|
||||
@@ -8,6 +8,6 @@ add_mlir_dialect_library(TritonGPUIR
|
||||
TritonGPUAttrDefsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRGPUOps
|
||||
MLIRGPUDialect
|
||||
TritonIR
|
||||
)
|
||||
|
||||
@@ -1603,7 +1603,7 @@ ParseResult parseInsertSliceOp(OpAsmParser &parser, OperationState &result) {
|
||||
result.operands))
|
||||
return failure();
|
||||
|
||||
// Deduce operand_segment_sizes from the number of the operands.
|
||||
// Deduce operandSegmentSizes from the number of the operands.
|
||||
auto operandSegmentSizesAttrName =
|
||||
OpT::getOperandSegmentSizesAttrName(result.name);
|
||||
result.addAttribute(
|
||||
@@ -1616,7 +1616,7 @@ template <class OpT>
|
||||
void printInsertSliceOp(OpAsmPrinter &printer, OpT insertSliceOp) {
|
||||
printer << " ";
|
||||
printer << insertSliceOp.getOperation()->getOperands();
|
||||
// "operand_segment_sizes" can be deduced, so we don't print it.
|
||||
// "operandSegmentSizes" can be deduced, so we don't print it.
|
||||
printer.printOptionalAttrDict(
|
||||
insertSliceOp->getAttrs(),
|
||||
{insertSliceOp.getOperandSegmentSizesAttrName()});
|
||||
|
||||
@@ -139,7 +139,10 @@ class BlockedToMMA : public mlir::RewritePattern {
|
||||
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
|
||||
int origBitWidth = finalBitWidth;
|
||||
SetVector<Operation *> slice;
|
||||
mlir::getBackwardSlice(x, &slice, bwdFilter);
|
||||
mlir::BackwardSliceOptions opt;
|
||||
opt.omitBlockArguments = true;
|
||||
opt.filter = bwdFilter;
|
||||
getBackwardSlice(x, &slice, opt);
|
||||
Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
|
||||
if (firstOp)
|
||||
if (Value arg = firstOp->getOperand(0))
|
||||
@@ -235,8 +238,11 @@ public:
|
||||
if (versionMajor == 1) {
|
||||
SetVector<Operation *> aBwdSlices, bBwdSlices;
|
||||
auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
|
||||
getBackwardSlice(a, &aBwdSlices, {isCvt});
|
||||
getBackwardSlice(b, &bBwdSlices, {isCvt});
|
||||
mlir::BackwardSliceOptions opt;
|
||||
opt.omitBlockArguments = true;
|
||||
opt.filter = isCvt;
|
||||
getBackwardSlice(a, &aBwdSlices, opt);
|
||||
getBackwardSlice(b, &bBwdSlices, opt);
|
||||
// get the source of the first conversion found in slices
|
||||
auto getCvtArgOrder = [](Operation *op) {
|
||||
return cast<ConvertLayoutOp>(op)
|
||||
|
||||
@@ -98,7 +98,9 @@ public:
|
||||
// and all operations between the load and the conversion
|
||||
// should be layout preserving
|
||||
SetVector<Operation *> slice;
|
||||
getBackwardSlice(op, &slice);
|
||||
mlir::BackwardSliceOptions opt;
|
||||
opt.omitBlockArguments = true;
|
||||
getBackwardSlice(op, &slice, opt);
|
||||
int loadIdx = -1;
|
||||
bool checkOp = false;
|
||||
for (int i = 0; i < slice.size(); i++) {
|
||||
|
||||
@@ -160,6 +160,8 @@ class LoopPipeliner {
|
||||
void checkOpShareBarriers(SetVector<Operation *> &ops);
|
||||
int numLoadsRequireAsyncWait = 0;
|
||||
int numLoadsRequireMBarrier = 0;
|
||||
// Number of buffers to allocate for each input.
|
||||
int numSharedMemorySlices = 0;
|
||||
|
||||
/// Iterator values
|
||||
Value nextIV;
|
||||
@@ -280,9 +282,12 @@ class LoopPipeliner {
|
||||
|
||||
public:
|
||||
LoopPipeliner(scf::ForOp forOp, int numStages, int numWarps, int numCTAs,
|
||||
bool mode, ConsumerReleaseMap &consumerReleaseMap)
|
||||
bool mode, int numSharedMemorySlices,
|
||||
ConsumerReleaseMap &consumerReleaseMap)
|
||||
: forOp(forOp), numStages(numStages), numWarps(numWarps),
|
||||
numCTAs(numCTAs), mode(mode), consumerReleaseMap(consumerReleaseMap) {
|
||||
numCTAs(numCTAs), mode(mode),
|
||||
numSharedMemorySlices(numSharedMemorySlices),
|
||||
consumerReleaseMap(consumerReleaseMap) {
|
||||
// cache yieldOp
|
||||
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
}
|
||||
@@ -644,7 +649,7 @@ void LoopPipeliner::createBufferTypes() {
|
||||
auto ty = loadOp.getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
|
||||
ty.getShape().end());
|
||||
bufferShape.insert(bufferShape.begin(), numStages);
|
||||
bufferShape.insert(bufferShape.begin(), numSharedMemorySlices);
|
||||
auto CTALayout = ttg::getCTALayout(ty.getEncoding());
|
||||
Attribute sharedEnc;
|
||||
if (auto dotOpEnc = cvt.getType()
|
||||
@@ -946,6 +951,11 @@ void LoopPipeliner::emitPrologue() {
|
||||
pipelineIterIdx = builder.create<arith::AddIOp>(
|
||||
iv.getLoc(), pipelineIterIdx,
|
||||
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
|
||||
Value numSlices = builder.create<arith::ConstantIntOp>(
|
||||
iv.getLoc(), numSharedMemorySlices, 32);
|
||||
Value _0 = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
|
||||
pipelineIterIdx = getBoundedIterationValue(builder, pipelineIterIdx,
|
||||
numSlices, pipelineIterIdx, _0);
|
||||
// Some values have not been used by any ops in the loop body
|
||||
for (BlockArgument arg : forOp.getRegionIterArgs())
|
||||
setValueMappingYield(arg, valueMapping[arg][stage], stage + 1);
|
||||
@@ -1220,11 +1230,13 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
|
||||
Value _1 = builder.create<arith::ConstantIntOp>(idxLoc, 1, 32);
|
||||
Value numStagesVal =
|
||||
builder.create<arith::ConstantIntOp>(idxLoc, numStages, 32);
|
||||
Value numSlices =
|
||||
builder.create<arith::ConstantIntOp>(idxLoc, numSharedMemorySlices, 32);
|
||||
|
||||
// nextWaitIdx
|
||||
Value waitIdxPlusOne = builder.create<arith::AddIOp>(idxLoc, curWaitIdx, _1);
|
||||
Value nextWaitIdx = getBoundedIterationValue(
|
||||
builder, waitIdxPlusOne, numStagesVal, waitIdxPlusOne, _0);
|
||||
Value nextWaitIdx = getBoundedIterationValue(builder, waitIdxPlusOne,
|
||||
numSlices, waitIdxPlusOne, _0);
|
||||
|
||||
// Indices of InsertSliceAsyncOp and ExtractSliceOp
|
||||
Value insertSliceIndex = pipelineIterIdx;
|
||||
@@ -1417,9 +1429,8 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
|
||||
// Bump pipelineIterIdx
|
||||
Value pipelineIterIdxPlusOne =
|
||||
builder.create<arith::AddIOp>(idxLoc, pipelineIterIdx, _1);
|
||||
pipelineIterIdx =
|
||||
getBoundedIterationValue(builder, pipelineIterIdxPlusOne, numStagesVal,
|
||||
pipelineIterIdxPlusOne, _0);
|
||||
pipelineIterIdx = getBoundedIterationValue(
|
||||
builder, pipelineIterIdxPlusOne, numSlices, pipelineIterIdxPlusOne, _0);
|
||||
|
||||
// Bump curWaitIdx
|
||||
curWaitIdx = nextWaitIdx;
|
||||
@@ -1516,10 +1527,23 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
// applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
|
||||
llvm::SmallVector<scf::ForOp> newForOps;
|
||||
|
||||
// Currently we schedule stage 0 after stage `numStages - 1` during
|
||||
// pipelining therefore we only need `numStages - 1` slice of memory.
|
||||
// On Hopper we have a separate post-processing that pipelines wgmma so we
|
||||
// need an extra buffer for each input.
|
||||
// Note that an alternative would be to keep allocating `numStages` buffers
|
||||
// and remove the barrier between the loads from shared memory and the
|
||||
// copies from global to shared. This would require improving existing
|
||||
// membar analysis.
|
||||
int numSharedMemorySlices =
|
||||
computeCapability < 90 ? numStages - 1 : numStages;
|
||||
|
||||
// Do the pipelining
|
||||
getOperation()->walk([&](scf::ForOp forOp) -> void {
|
||||
LoopPipeliner pipeliner(forOp, this->numStages, this->numWarps,
|
||||
this->numCTAs, mode, consumerReleaseMap);
|
||||
this->numCTAs, mode, numSharedMemorySlices,
|
||||
consumerReleaseMap);
|
||||
if (pipeliner.initialize().failed())
|
||||
return;
|
||||
|
||||
@@ -1593,7 +1617,8 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
|
||||
|
||||
/// XXX(Keren): Clean up the following duplicate code with checkDotOp
|
||||
/// dots to be pipelined
|
||||
SetVector<Value> dots;
|
||||
SmallVector<tt::DotOp> dots;
|
||||
SmallVector<unsigned> resultNeedSync;
|
||||
for (Operation &op : *loop) {
|
||||
if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
|
||||
auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
@@ -1615,8 +1640,11 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
|
||||
if (!CArg || !CArg.hasOneUse())
|
||||
valid = false;
|
||||
|
||||
if (valid)
|
||||
dots.insert(dotOp);
|
||||
if (valid) {
|
||||
dots.push_back(dotOp);
|
||||
resultNeedSync.push_back(
|
||||
dotOp->getUses().begin()->getOperandNumber());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1627,39 +1655,39 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
|
||||
return;
|
||||
|
||||
OpBuilder builder(forOp);
|
||||
|
||||
// 0. insert dot_wait after the last dot in the loop
|
||||
Value dot = dots.back();
|
||||
auto loc = dot.getLoc();
|
||||
builder.setInsertionPointAfter(dot.getDefiningOp());
|
||||
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(loc, dots.size());
|
||||
// 0. insert dot_wait after the last dot in the loop as we implicitly pipeline
|
||||
// wgmma ops by one stage.
|
||||
// This is needed to prevent shared memory inputs from being overwritten before the
|
||||
// operation is completed.
|
||||
// TODO: merge this with the rest of the pipelining transformation and look at
|
||||
// a better representation for async dots.
|
||||
tt::DotOp lastDot = dots.back();
|
||||
builder.setInsertionPointAfter(lastDot);
|
||||
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
|
||||
lastDot.getLoc(), lastDot.getResult(), dots.size());
|
||||
|
||||
// 1. replace Dot with DotAsync
|
||||
for (size_t idx = 0; idx < dots.size(); ++idx) {
|
||||
Value dot = dots[idx];
|
||||
auto dotOp = cast<tt::DotOp>(dot.getDefiningOp());
|
||||
builder.setInsertionPoint(dot.getDefiningOp());
|
||||
tt::DotOp dotOp = dots[idx];
|
||||
builder.setInsertionPoint(dotOp);
|
||||
auto dotAsync = builder.create<tt::nvidia_gpu::DotAsyncOp>(
|
||||
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(),
|
||||
dotOp.getMaxNumImpreciseAcc());
|
||||
dot.replaceAllUsesWith(dotAsync.getResult());
|
||||
updateConsumerReleaseInfo(dot.getDefiningOp(), dotWait, /*stage=*/1);
|
||||
dot.getDefiningOp()->erase();
|
||||
dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(),
|
||||
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
|
||||
dotOp.replaceAllUsesWith(dotAsync.getResult());
|
||||
updateConsumerReleaseInfo(dotOp, dotWait, /*stage=*/1);
|
||||
dotOp->erase();
|
||||
}
|
||||
|
||||
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
|
||||
builder.setInsertionPointAfter(forOp);
|
||||
Value loopNotEmpty = builder.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, forOp.getLowerBound(),
|
||||
forOp.getUpperBound());
|
||||
// TODO[goostavz]: it's a workaround to put the DotWaitOp in an IfOp for
|
||||
// a bug in ptxas which mistakenly analyzes the control flow and turns the GMMA
|
||||
// into a synchronous implementation for safety.
|
||||
// Remove this If once the bug is fixed.
|
||||
auto ifOp = builder.create<scf::IfOp>(loc, ArrayRef<Type>{}, loopNotEmpty,
|
||||
/*hasElse*/ false);
|
||||
builder.setInsertionPointToStart(ifOp.thenBlock());
|
||||
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), 0);
|
||||
for (unsigned resultIndex : resultNeedSync) {
|
||||
Value result = forOp->getResult(resultIndex);
|
||||
if (result.use_empty())
|
||||
continue;
|
||||
auto dotWait =
|
||||
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), result, 0);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
|
||||
}
|
||||
}
|
||||
|
||||
Value PipelinePass::getRemoteCTAId(OpBuilder &b, Location loc,
|
||||
|
||||
@@ -31,6 +31,7 @@ using triton::gpu::SliceEncodingAttr;
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
<<<<<<< HEAD
|
||||
// convert(blocked, dot_operand) ->
|
||||
// convert(blocked, mma) + convert(mma, dot_operand)
|
||||
// if this value is itself the result of a dot operation
|
||||
@@ -102,6 +103,9 @@ public:
|
||||
};
|
||||
|
||||
//
|
||||
=======
|
||||
// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
class ConvertDotConvert : public mlir::RewritePattern {
|
||||
public:
|
||||
ConvertDotConvert(mlir::MLIRContext *context)
|
||||
@@ -233,12 +237,17 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
|
||||
getForwardSlice(currentValue, &forwardSlice);
|
||||
for (Operation *op : forwardSlice) {
|
||||
if (auto convertOp = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
|
||||
if (convertOp.getResult()
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.isa<triton::gpu::MmaEncodingAttr>())
|
||||
return true;
|
||||
Attribute dstEncoding = convertOp.getResult()
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding();
|
||||
if (auto mmaLayout =
|
||||
dstEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>())
|
||||
return (mmaLayout.getVersionMajor() > 1) ? true
|
||||
: mmaLayout == encoding;
|
||||
if (dstEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
|
||||
return encoding.cast<triton::gpu::MmaEncodingAttr>()
|
||||
.getVersionMajor() > 1;
|
||||
}
|
||||
auto yield = dyn_cast<scf::YieldOp>(op);
|
||||
if (!yield)
|
||||
@@ -560,6 +569,15 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
|
||||
return rewrittenValue;
|
||||
OpBuilder rewriter(value.getContext());
|
||||
rewriter.setInsertionPointAfterValue(rewrittenValue);
|
||||
// Workaround: The pipeliner will insert async.wait after a pipelined loop
|
||||
// to ensure that there is no pending copies and it is safe to re-use shared
|
||||
// memory. We shouldn't insert ops that may use shared memory in between the
|
||||
// loop and the async.wait. This is a hack until we fix the IR
|
||||
// representation of async wait.
|
||||
if (Operation *op = rewrittenValue.getDefiningOp()) {
|
||||
if (isa<triton::gpu::AsyncWaitOp>(op->getNextNode()))
|
||||
rewriter.setInsertionPointAfter(op->getNextNode());
|
||||
}
|
||||
auto tmpType = RankedTensorType::get(tensorType.getShape(),
|
||||
tensorType.getElementType(), encoding);
|
||||
Value converted = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
@@ -1122,7 +1140,6 @@ public:
|
||||
hoistConvert(m);
|
||||
|
||||
mlir::RewritePatternSet decomposePatterns(context);
|
||||
decomposePatterns.add<DecomposeDotOperand>(context);
|
||||
decomposePatterns.add<ConvertDotConvert>(context);
|
||||
if (mlir::applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
|
||||
.failed()) {
|
||||
|
||||
@@ -91,7 +91,7 @@ private:
|
||||
// support ForOp only
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
|
||||
// prologue
|
||||
auto iterOperands = forOp.getIterOperands();
|
||||
auto iterOperands = forOp.getInitArgs();
|
||||
if (argNum == 0)
|
||||
return false;
|
||||
if (dependOnSharedEncOperand(iterOperands[argNum - 1]))
|
||||
|
||||
@@ -628,12 +628,13 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
|
||||
arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp,
|
||||
arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp,
|
||||
arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp,
|
||||
arith::MaxFOp, arith::MaxSIOp, arith::MaxUIOp, arith::MinFOp,
|
||||
arith::MinSIOp, arith::MinUIOp, arith::MulFOp, arith::MulIOp,
|
||||
arith::NegFOp, arith::OrIOp, arith::RemFOp, arith::RemSIOp,
|
||||
arith::RemUIOp, arith::ShLIOp, arith::ShRSIOp, arith::ShRUIOp,
|
||||
arith::SIToFPOp, arith::SubFOp, arith::SubIOp, arith::TruncFOp,
|
||||
arith::TruncIOp, arith::UIToFPOp, arith::XOrIOp>(op))
|
||||
arith::MaximumFOp, arith::MaxSIOp, arith::MaxUIOp,
|
||||
arith::MinimumFOp, arith::MinSIOp, arith::MinUIOp,
|
||||
arith::MulFOp, arith::MulIOp, arith::NegFOp, arith::OrIOp,
|
||||
arith::RemFOp, arith::RemSIOp, arith::RemUIOp, arith::ShLIOp,
|
||||
arith::ShRSIOp, arith::ShRUIOp, arith::SIToFPOp, arith::SubFOp,
|
||||
arith::SubIOp, arith::TruncFOp, arith::TruncIOp,
|
||||
arith::UIToFPOp, arith::XOrIOp>(op))
|
||||
return true;
|
||||
if (llvm::isa<math::AbsFOp, math::AbsIOp, math::AtanOp, math::Atan2Op,
|
||||
math::CeilOp, math::CopySignOp, math::CosOp, math::SinOp,
|
||||
|
||||
@@ -220,7 +220,9 @@ public:
|
||||
SetVector<Operation *> backwardSlice;
|
||||
mod.walk([&](triton::MakeTensorPtrOp op) -> void {
|
||||
assert(isa<triton::FuncOp>(op->getParentOp()));
|
||||
getBackwardSlice(op.getOperation(), &backwardSlice);
|
||||
mlir::BackwardSliceOptions opt;
|
||||
opt.omitBlockArguments = true;
|
||||
getBackwardSlice(op.getOperation(), &backwardSlice, opt);
|
||||
op->removeAttr("async_agent");
|
||||
});
|
||||
for (auto op : backwardSlice) {
|
||||
|
||||
@@ -79,6 +79,7 @@ void materializeGetAgentIdOp(Operation *parentOp) {
|
||||
builder.setInsertionPoint(agentIdOp);
|
||||
Value globalRoleId = builder.create<arith::ConstantIntOp>(loc, 0, 32);
|
||||
int globalNumWarps = 0;
|
||||
SmallVector<Operation *> deprecatedOps;
|
||||
for (auto cmpOp : agentIdOp->getUsers()) {
|
||||
assert(isa<arith::CmpIOp>(cmpOp));
|
||||
for (auto u : cmpOp->getUsers()) {
|
||||
@@ -111,11 +112,14 @@ void materializeGetAgentIdOp(Operation *parentOp) {
|
||||
Value cond =
|
||||
builder.create<arith::AndIOp>(loc, lowerBound, upperBound);
|
||||
cmpOp->getResult(0).replaceAllUsesWith(cond);
|
||||
cmpOp->erase();
|
||||
deprecatedOps.push_back(cmpOp);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (Operation *cmpOp : deprecatedOps) {
|
||||
cmpOp->erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -145,39 +149,24 @@ LoadType scanLoadTypes(ttng::CreateTokenOp createTokenOp) {
|
||||
}
|
||||
|
||||
Value getMBarrierPhaseBit(OpBuilder &builder, Operation *op,
|
||||
bool skipFirstWait) {
|
||||
bool emptyBarrier) {
|
||||
// TODO: currently we only support one loop, no nested loop, while or
|
||||
// condition.
|
||||
auto loc = op->getLoc();
|
||||
auto forOp = op->getParentOfType<scf::ForOp>();
|
||||
if (!forOp) {
|
||||
return builder.create<arith::ConstantIntOp>(loc, skipFirstWait, 1);
|
||||
return builder.create<arith::ConstantIntOp>(loc, emptyBarrier, 1);
|
||||
}
|
||||
|
||||
auto defOp = op->getOperand(0).getDefiningOp();
|
||||
assert(isa<ttng::CreateTokenOp>(defOp) &&
|
||||
"mbarrier's definingOp is not createTokenOp");
|
||||
ttng::CreateTokenOp createTokenOp = dyn_cast<ttng::CreateTokenOp>(defOp);
|
||||
Value numStage =
|
||||
builder.create<arith::ConstantIntOp>(loc, createTokenOp.getNum(), 32);
|
||||
Value curStep = forOp.getBody()->getArguments().back();
|
||||
if (curStep.getType() == builder.getIndexType()) {
|
||||
curStep =
|
||||
builder.create<arith::IndexCastOp>(loc, numStage.getType(), curStep);
|
||||
// for (..., phase, pipelineIdx)
|
||||
unsigned numArgs = forOp.getBody()->getNumArguments();
|
||||
assert(numArgs > 2 && "Unexpected number of arguments");
|
||||
Value curPhase = forOp.getBody()->getArgument(numArgs - 2);
|
||||
if (emptyBarrier) {
|
||||
Value _1_1b = builder.create<arith::ConstantIntOp>(loc, 1, 1);
|
||||
curPhase = builder.create<mlir::arith::XOrIOp>(loc, curPhase, _1_1b);
|
||||
}
|
||||
Value curPhase = builder.create<arith::DivUIOp>(loc, curStep, numStage);
|
||||
if (skipFirstWait) {
|
||||
// If skipFirstWait, it waits for phaseBit 1
|
||||
Value _1 = builder.create<arith::ConstantIntOp>(loc, 1, 32);
|
||||
curPhase = builder.create<arith::AddIOp>(loc, curPhase, _1);
|
||||
}
|
||||
Value _2 = builder.create<arith::ConstantIntOp>(loc, 2, 32);
|
||||
// TODO: May use alternative methods of phaseBit calculation to avoid high
|
||||
// overhead of RemOp
|
||||
Value phaseBit = builder.create<arith::RemUIOp>(loc, curPhase, _2);
|
||||
Value _0 = builder.create<arith::ConstantIntOp>(loc, 0, 32);
|
||||
return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, phaseBit,
|
||||
_0);
|
||||
return curPhase;
|
||||
}
|
||||
|
||||
int getTxBytes(ttng::InsertSliceAsyncV2Op load) {
|
||||
@@ -260,7 +249,7 @@ void processProducerAcquireOp(OpBuilder &builder, ttng::ProducerAcquireOp op,
|
||||
auto loc = op.getLoc();
|
||||
// The first producer_acquire should be met immediately, so initially the producer
|
||||
// skips the first wait
|
||||
Value phase = getMBarrierPhaseBit(builder, op, 1);
|
||||
Value phase = getMBarrierPhaseBit(builder, op, true);
|
||||
auto waitOp = builder.create<ttng::MBarrierWaitOp>(loc, bufferEmpty, phase);
|
||||
assert(op.getOperation()->hasAttr("async_agent"));
|
||||
setAgentIds(waitOp, getAgentIds(op.getOperation()));
|
||||
@@ -296,7 +285,7 @@ void processProducerCommitOp(OpBuilder &builder, ttng::ProducerCommitOp op,
|
||||
void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op,
|
||||
Value bufferFull) {
|
||||
auto loc = op.getLoc();
|
||||
Value phase = getMBarrierPhaseBit(builder, op, 0);
|
||||
Value phase = getMBarrierPhaseBit(builder, op, false);
|
||||
auto waitOp = builder.create<ttng::MBarrierWaitOp>(loc, bufferFull, phase);
|
||||
assert(op.getOperation()->hasAttr("async_agent"));
|
||||
setAgentIds(waitOp, getAgentIds(op.getOperation()));
|
||||
@@ -530,6 +519,7 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId,
|
||||
builder.create<arith::ConstantIntOp>(loc, nameBarrierId - 1, 32);
|
||||
// Process mutex users
|
||||
int numUsers = 0;
|
||||
SmallVector<Operation *> deprecatedOps;
|
||||
for (Operation *user : createMutexOp.getResult().getUsers()) {
|
||||
numUsers++;
|
||||
assert(numUsers <= 2);
|
||||
@@ -543,14 +533,20 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId,
|
||||
Value barLeave = builder.create<arith::SelectOp>(
|
||||
loc, isRole0, namedBarrierId1, namedBarrierId0);
|
||||
builder.create<ttng::NamedBarrierArriveOp>(loc, barLeave, numThreads);
|
||||
} else
|
||||
} else {
|
||||
llvm_unreachable("Unexpected user of mutex");
|
||||
}
|
||||
deprecatedOps.push_back(user);
|
||||
}
|
||||
for (Operation *user : deprecatedOps) {
|
||||
user->erase();
|
||||
}
|
||||
nameBarrierId -= 2;
|
||||
nameBarrierIdEnd -= 2;
|
||||
createMutexOp.erase();
|
||||
});
|
||||
|
||||
parentOp->walk(
|
||||
[](ttng::CreateMutexOp createMutexOp) { createMutexOp.erase(); });
|
||||
}
|
||||
|
||||
void processLockOp(OpBuilder &builder, ttng::LockOp op) {
|
||||
@@ -587,6 +583,7 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) {
|
||||
OpBuilder builder(createMutexOp);
|
||||
|
||||
// Process mutex users
|
||||
SmallVector<Operation *> deprecatedOps;
|
||||
for (Operation *user : createMutexOp.getResult().getUsers()) {
|
||||
auto loc = user->getLoc();
|
||||
builder.setInsertionPoint(user);
|
||||
@@ -596,6 +593,9 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) {
|
||||
processUnlockOp(builder, op);
|
||||
else
|
||||
llvm_unreachable("Unexpected user of mutex");
|
||||
deprecatedOps.push_back(user);
|
||||
}
|
||||
for (Operation *user : deprecatedOps) {
|
||||
user->erase();
|
||||
}
|
||||
|
||||
|
||||
@@ -156,14 +156,20 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
|
||||
persistentForOp.getInitArgsMutable()
|
||||
.slice(persistentForOp.getInitArgs().size() - 1, 1)
|
||||
.assign(newIdx);
|
||||
auto yield =
|
||||
llvm::cast<scf::YieldOp>(persistentForOp.getBody()->getTerminator());
|
||||
auto idxPlusOneOp =
|
||||
yield->getOperand(yield->getNumOperands() - 1).getDefiningOp();
|
||||
assert(isa<arith::AddIOp>(idxPlusOneOp));
|
||||
assert(idxPlusOneOp->getOperand(0) ==
|
||||
persistentForOp.getBody()->getArgument(
|
||||
persistentForOp.getBody()->getNumArguments() - 1));
|
||||
|
||||
pipelineIdx = persistentForOp.getBody()->getArgument(
|
||||
persistentForOp.getBody()->getNumArguments() - 1);
|
||||
Operation *idxPlusOneOp = nullptr;
|
||||
for (OpOperand &v : pipelineIdx.getUses()) {
|
||||
if (isa<arith::AddIOp>(v.getOwner())) {
|
||||
idxPlusOneOp = v.getOwner();
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert(idxPlusOneOp && "idxPlusOneOp should be arith::AddIOp");
|
||||
Operation *use = *idxPlusOneOp->getUsers().begin();
|
||||
assert(isa<scf::YieldOp>(use) || isa<arith::SelectOp>(use) ||
|
||||
isa<arith::CmpIOp>(use));
|
||||
idxPlusOneOp->setOperand(1, numRolesValue);
|
||||
|
||||
// Add operations at the start of persistentForOp
|
||||
@@ -213,45 +219,6 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
|
||||
unlockLocs[i] = op;
|
||||
}
|
||||
|
||||
// Update unlockLocs
|
||||
// ====================== IR after async launch dots ======================
|
||||
// * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 =
|
||||
// %3) {
|
||||
// * triton_nvidia_gpu.producer_wait arg2
|
||||
// * %5 = triton_nvidia_gpu.dot_async %4, %5
|
||||
// * triton_nvidia_gpu.dot_wait {pendings = 1}
|
||||
// * %6 = arith.cmpi sgt, arg0, %c0
|
||||
// * scf.if %6 {
|
||||
// * %7 = arith.subi arg2, c1
|
||||
// * triton_nvidia_gpu.consumer_release %7
|
||||
// * }
|
||||
// * %8 = arith.addi arg2, c1
|
||||
// * scf.yield %5, %8
|
||||
// * }
|
||||
// * triton_nvidia_gpu.dot_wait {pendings = 0}
|
||||
// * %9 = arith.subi %0#1, c1
|
||||
// * triton_nvidia_gpu.consumer_release %9
|
||||
// * =======================================================================
|
||||
// after async launch dots, there will be outstanding consumerReleaseOp after
|
||||
// ForOp. we should expend the unlockLocs from ForOp to the outstanding
|
||||
// consumerReleaseOp.
|
||||
for (int i = 0; i < numRoles; ++i) {
|
||||
Operation *unlockOp = unlockLocs[i];
|
||||
auto filter = [&](Operation *op) {
|
||||
return op->getBlock() == unlockOp->getBlock();
|
||||
};
|
||||
if (isa<scf::ForOp>(unlockOp)) {
|
||||
SetVector<Operation *> slices;
|
||||
mlir::getForwardSlice(unlockOp->getResults().back(), &slices, {filter});
|
||||
auto iter = llvm::find_if(slices, [](Operation *op) {
|
||||
return isa<triton::nvidia_gpu::ConsumerReleaseOp>(op);
|
||||
});
|
||||
if (iter != slices.end()) {
|
||||
unlockLocs[i] = *iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only cases where all lock/unlock locations are in same level make sense.
|
||||
for (int i = 1; i < numRoles; ++i) {
|
||||
if (lockLocs[i]->getParentOp() != lockLocs[i - 1]->getParentOp() ||
|
||||
@@ -281,6 +248,54 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
|
||||
else
|
||||
lockLocs[i] = unlockLocs[prevTypeIds[i]];
|
||||
}
|
||||
|
||||
// Update lockLocs
|
||||
// ====================== IR after async launch dots ======================
|
||||
// * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 =
|
||||
// %3) {
|
||||
// * triton_nvidia_gpu.producer_wait arg2
|
||||
// * %5 = triton_nvidia_gpu.dot_async %4, %5
|
||||
// * triton_nvidia_gpu.dot_wait {pendings = 1}
|
||||
// * %6 = arith.cmpi sgt, arg0, %c0
|
||||
// * scf.if %6 {
|
||||
// * %7 = arith.subi arg2, c1
|
||||
// * triton_nvidia_gpu.consumer_release %7
|
||||
// * }
|
||||
// * %8 = arith.addi arg2, c1
|
||||
// * scf.yield %5, %8
|
||||
// * }
|
||||
// * triton_nvidia_gpu.dot_wait {pendings = 0}
|
||||
// * ...
|
||||
// * triton_nvidia_gpu.consumer_release ..
|
||||
// * =======================================================================
|
||||
// after async launch dots, there will be outstanding consumerReleaseOp after
|
||||
// ForOp. we should set the epilogue lockLocs after the outstanding
|
||||
// consumerReleaseOp.
|
||||
for (int i = 0; i < numRoles; ++i) {
|
||||
Operation *lockOp = lockLocs[i];
|
||||
if (isa<scf::ForOp>(lockOp)) {
|
||||
Operation *loc = nullptr;
|
||||
unsigned numOutstandingConsumerRelease = 0;
|
||||
for (auto v : lockOp->getResults()) {
|
||||
SetVector<Operation *> slices;
|
||||
mlir::getForwardSlice(v, &slices);
|
||||
auto iter = llvm::find_if(slices, [](Operation *op) {
|
||||
return isa<triton::nvidia_gpu::ConsumerReleaseOp>(op);
|
||||
});
|
||||
if (iter != slices.end()) {
|
||||
numOutstandingConsumerRelease++;
|
||||
loc = *iter;
|
||||
}
|
||||
}
|
||||
assert(numOutstandingConsumerRelease <= 1 &&
|
||||
"should have only one outstanding "
|
||||
"consumerReleaseOp after "
|
||||
"async launch dots");
|
||||
if (loc)
|
||||
lockLocs[i] = loc;
|
||||
}
|
||||
}
|
||||
|
||||
// lock
|
||||
for (int i = 0; i < numRoles; ++i) {
|
||||
builder.setInsertionPointAfter(lockLocs[i]);
|
||||
|
||||
@@ -129,11 +129,12 @@ DenseMap<Channel *, Value> createBuffer(const SmallVector<Channel *> &channels,
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// appendPipelineIdxToLoopArgs
|
||||
// createNewLoops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages,
|
||||
scf::ForOp &parentForOp) {
|
||||
// for(...) -> for(..., pipelineIdx)
|
||||
scf::ForOp createNewPersistentLoop(scf::ForOp forOp, int numStages,
|
||||
scf::ForOp &parentForOp) {
|
||||
auto loc = forOp.getLoc();
|
||||
Block *body = forOp.getBody();
|
||||
|
||||
@@ -200,6 +201,117 @@ scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages,
|
||||
return newForOp;
|
||||
}
|
||||
|
||||
// for(...) -> for(..., phase, pipelineIdx)
|
||||
scf::ForOp createNewMathLoop(scf::ForOp forOp, int numStages,
|
||||
scf::ForOp &parentForOp) {
|
||||
auto loc = forOp.getLoc();
|
||||
Block *body = forOp.getBody();
|
||||
|
||||
// The agentId set of pipelineIdx is the union of agentId sets of all ops in
|
||||
// the for loop
|
||||
OpBuilderWithAgentIds builder(forOp.getContext());
|
||||
builder.setAgentIdsFromArray(collectAgentIds(forOp));
|
||||
|
||||
builder.setInsertionPoint(forOp);
|
||||
Value numStagesVal =
|
||||
builder.createWithAgentIds<arith::ConstantIntOp>(loc, numStages, 32);
|
||||
|
||||
// 0. Append pipelineIdx to block arguments
|
||||
Value phase =
|
||||
body->insertArgument(body->getNumArguments(), builder.getI1Type(), loc);
|
||||
Value pipelineIdx =
|
||||
body->insertArgument(body->getNumArguments(), builder.getI32Type(), loc);
|
||||
|
||||
// 1. prepare index and phase for next iteration
|
||||
// nextIdx = curIdx + 1
|
||||
// nextPhase = ((nextIdx < numStages && curPhase) || (nextIdx >= numStages &&
|
||||
// curPhase^1))
|
||||
// nextIdx = nextIdx >= numStages ? 0 : nextIdx
|
||||
auto yieldOp = llvm::cast<scf::YieldOp>(body->getTerminator());
|
||||
builder.setInsertionPoint(yieldOp);
|
||||
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
|
||||
Value zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
Value _1_1b = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 1);
|
||||
// generate index for next iter
|
||||
Value nextPipelineIdx =
|
||||
builder.createWithAgentIds<arith::AddIOp>(loc, pipelineIdx, one);
|
||||
Value pipelineGECond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::uge, nextPipelineIdx, numStagesVal);
|
||||
Value pipelineLTCond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::ult, nextPipelineIdx, numStagesVal);
|
||||
Value cyclePipelineIdx = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, nextPipelineIdx, numStagesVal);
|
||||
nextPipelineIdx = builder.createWithAgentIds<mlir::arith::SelectOp>(
|
||||
loc, pipelineGECond, cyclePipelineIdx, nextPipelineIdx);
|
||||
// generate phase for next iter
|
||||
Value flipPhase =
|
||||
builder.createWithAgentIds<mlir::arith::XOrIOp>(loc, phase, _1_1b);
|
||||
Value cond0 = builder.createWithAgentIds<mlir::arith::AndIOp>(
|
||||
loc, pipelineGECond, flipPhase);
|
||||
Value cond1 = builder.createWithAgentIds<mlir::arith::AndIOp>(
|
||||
loc, pipelineLTCond, phase);
|
||||
Value nextPhase =
|
||||
builder.createWithAgentIds<mlir::arith::OrIOp>(loc, cond0, cond1);
|
||||
|
||||
// 2. Append pipelineIdx to yield operands
|
||||
yieldOp->insertOperands(yieldOp.getNumOperands(),
|
||||
{nextPhase, nextPipelineIdx});
|
||||
|
||||
// 3. create newLoopArgs
|
||||
SmallVector<Value> newLoopArgs;
|
||||
for (auto operand : forOp.getInitArgs())
|
||||
newLoopArgs.push_back(operand);
|
||||
|
||||
builder.setInsertionPoint(forOp);
|
||||
Value initPipelineIdx, initEmptyIdx, initPhase;
|
||||
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
if (parentForOp) {
|
||||
// Make sure prior pipelineIdx is inserted in the end of parentForOp
|
||||
initPipelineIdx = parentForOp.getBody()->getArguments().back();
|
||||
Value numSteps = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, forOp.getUpperBound(), forOp.getLowerBound());
|
||||
numSteps = builder.createWithAgentIds<arith::AddIOp>(loc, numSteps,
|
||||
forOp.getStep());
|
||||
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
|
||||
Value two = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 2, 32);
|
||||
numSteps = builder.createWithAgentIds<arith::SubIOp>(loc, numSteps, one);
|
||||
numSteps = builder.createWithAgentIds<arith::DivUIOp>(loc, numSteps,
|
||||
forOp.getStep());
|
||||
// initPipelineIdx = (parentForOp.pipelineIdx * numSteps) % numStages
|
||||
// initPhase = ((parentForOp.pipelineIdx * numSteps) / numStages) & 1
|
||||
initPipelineIdx = builder.createWithAgentIds<arith::MulIOp>(
|
||||
loc, initPipelineIdx, numSteps);
|
||||
Value pipelineIdx = builder.createWithAgentIds<arith::DivUIOp>(
|
||||
loc, initPipelineIdx, numStagesVal);
|
||||
initPipelineIdx = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, initPipelineIdx,
|
||||
builder.createWithAgentIds<arith::MulIOp>(loc, pipelineIdx,
|
||||
numStagesVal));
|
||||
pipelineIdx =
|
||||
builder.createWithAgentIds<arith::AndIOp>(loc, pipelineIdx, one);
|
||||
initPhase = builder.createWithAgentIds<arith::TruncIOp>(
|
||||
loc, builder.getI1Type(), pipelineIdx);
|
||||
} else {
|
||||
// phase init to false and pipelineIdx init to 0
|
||||
initPipelineIdx = zero;
|
||||
initPhase = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 1);
|
||||
}
|
||||
newLoopArgs.append({initPhase, initPipelineIdx});
|
||||
|
||||
// 4. Create newForOp and take the region of forOp
|
||||
auto newForOp = builder.createWithAgentIds<scf::ForOp>(
|
||||
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(),
|
||||
newLoopArgs);
|
||||
newForOp.getRegion().takeBody(forOp.getRegion());
|
||||
|
||||
// 5. Replace forOp with newForOp
|
||||
for (unsigned i = 0; i < forOp.getNumResults(); ++i)
|
||||
forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i));
|
||||
forOp.erase();
|
||||
|
||||
return newForOp;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// appendPipelineIdxArgs
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -217,7 +329,22 @@ void appendPipelineIdxArgs(SmallVector<Operation *> &backbone, int numStages) {
|
||||
|
||||
for (auto &op : orderedForOps) {
|
||||
scf::ForOp parentForOp = op->getParentOfType<scf::ForOp>();
|
||||
auto newForOp = appendPipelineIdxToLoopArgs(op, numStages, parentForOp);
|
||||
scf::ForOp newForOp;
|
||||
bool hasDotOp = false;
|
||||
for (Operation &subOp : *op.getBody()) {
|
||||
if (isa<triton::DotOp>(&subOp)) {
|
||||
hasDotOp = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasDotOp) {
|
||||
// for(...) -> for(..., phase, pipelineIdx)
|
||||
newForOp = createNewMathLoop(op, numStages, parentForOp);
|
||||
} else {
|
||||
// for(...) -> for(..., pipelineIdx)
|
||||
newForOp = createNewPersistentLoop(op, numStages, parentForOp);
|
||||
}
|
||||
auto backboneForItr =
|
||||
std::find(backbone.begin(), backbone.end(), op.getOperation());
|
||||
if (backboneForItr != backbone.end()) {
|
||||
@@ -688,8 +815,6 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
}
|
||||
builder.setAgentIdsFromArray(agentsPC);
|
||||
Value pipelineIdx;
|
||||
Value numStagesVal = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
headProducer->getLoc(), numStages, 32);
|
||||
if (auto forOp = headProducer->getParentOfType<scf::ForOp>()) {
|
||||
pipelineIdx = forOp.getBody()->getArguments().back();
|
||||
} else {
|
||||
@@ -700,10 +825,6 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
|
||||
// insert ProducerAcquireOp
|
||||
builder.setInsertionPoint(headProducer);
|
||||
if (headProducer->getParentOfType<scf::ForOp>()) {
|
||||
pipelineIdx = builder.createWithAgentIds<arith::RemSIOp>(
|
||||
headProducer->getLoc(), pipelineIdx, numStagesVal);
|
||||
}
|
||||
builder.setAgentIdsFromArray(agentP);
|
||||
builder.createWithAgentIds<ttng::ProducerAcquireOp>(headProducer->getLoc(),
|
||||
token, pipelineIdx);
|
||||
@@ -738,7 +859,8 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(),
|
||||
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
|
||||
dot.replaceAllUsesWith(dotAsync.getResult());
|
||||
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(loc, 1);
|
||||
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
|
||||
loc, dotAsync.getResult(), 1);
|
||||
|
||||
// 1. insert ConsumerReleaseOp for DotAsyncOps
|
||||
Value cond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
@@ -747,31 +869,43 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
auto ifOp =
|
||||
builder.createWithAgentIds<scf::IfOp>(loc, ArrayRef<Type>{}, cond,
|
||||
/*hasElse*/ false);
|
||||
setAgentIds(ifOp.thenYield().getOperation(), agentIds);
|
||||
builder.setInsertionPointToStart(ifOp.thenBlock());
|
||||
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
headConsumer->getLoc(), 1, 32);
|
||||
auto oriIdx = forOp.getBody()->getArguments().back();
|
||||
Value consumerReleaseIdx =
|
||||
builder.createWithAgentIds<arith::SubIOp>(loc, oriIdx, one);
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::RemSIOp>(
|
||||
loc, consumerReleaseIdx, numStagesVal);
|
||||
Value consumerReleaseIdx = forOp.getBody()->getArguments().back();
|
||||
Value zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
|
||||
Value lastStage = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
loc, numStages - 1, 32);
|
||||
Value consumerReleaseIdxMinusOne =
|
||||
builder.createWithAgentIds<arith::SubIOp>(loc, consumerReleaseIdx,
|
||||
one);
|
||||
cond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, consumerReleaseIdx, zero);
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::SelectOp>(
|
||||
loc, cond, lastStage, consumerReleaseIdxMinusOne);
|
||||
builder.createWithAgentIds<ttng::ConsumerReleaseOp>(loc, token,
|
||||
consumerReleaseIdx);
|
||||
setAgentIds(ifOp.thenYield().getOperation(), agentIds);
|
||||
|
||||
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
|
||||
builder.setInsertionPointAfter(forOp);
|
||||
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(forOp.getLoc(),
|
||||
0);
|
||||
unsigned resultIndex = dotAsync->getUses().begin()->getOperandNumber();
|
||||
Value result = forOp->getResult(resultIndex);
|
||||
auto dotWait = builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
|
||||
forOp.getLoc(), result, 0);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
|
||||
|
||||
// 3. insert ConsumerReleaseOp for outstanding DotAsyncOps
|
||||
Value one_ = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
headConsumer->getLoc(), 1, 32);
|
||||
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
|
||||
lastStage = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
loc, numStages - 1, 32);
|
||||
consumerReleaseIdx = forOp.getResults().back();
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, consumerReleaseIdx, one_);
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::RemSIOp>(
|
||||
loc, consumerReleaseIdx, numStagesVal);
|
||||
consumerReleaseIdxMinusOne = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, consumerReleaseIdx, one);
|
||||
cond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, consumerReleaseIdx, zero);
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::SelectOp>(
|
||||
loc, cond, lastStage, consumerReleaseIdxMinusOne);
|
||||
builder.createWithAgentIds<ttng::ConsumerReleaseOp>(loc, token,
|
||||
consumerReleaseIdx);
|
||||
dotOp->erase();
|
||||
|
||||
@@ -14,7 +14,6 @@ add_mlir_translation_library(TritonLLVMIR
|
||||
PUBLIC
|
||||
MLIRArithToLLVM
|
||||
MLIRBuiltinToLLVMIRTranslation
|
||||
MLIRExecutionEngineUtils
|
||||
MLIRIndexToLLVM
|
||||
MLIRIR
|
||||
MLIRLLVMDialect
|
||||
|
||||
@@ -44,7 +44,8 @@ static bool findAndReplace(std::string &str, const std::string &begin,
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
|
||||
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version,
|
||||
bool enable_fp_fusion) {
|
||||
// LLVM version in use may not officially support target hardware.
|
||||
// Supported versions for LLVM 14 are here:
|
||||
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
|
||||
@@ -84,13 +85,15 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
|
||||
auto target =
|
||||
llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
|
||||
llvm::TargetOptions opt;
|
||||
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
||||
if (enable_fp_fusion)
|
||||
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
||||
opt.UnsafeFPMath = false;
|
||||
opt.NoInfsFPMath = false;
|
||||
opt.NoNaNsFPMath = true;
|
||||
opt.TrapUnreachable = true;
|
||||
std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
|
||||
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
|
||||
std::nullopt, llvm::CodeGenOpt::Aggressive)};
|
||||
std::nullopt, llvm::CodeGenOptLevel::Aggressive)};
|
||||
// set data layout
|
||||
if (layout.empty())
|
||||
module.setDataLayout(machine->createDataLayout());
|
||||
@@ -106,7 +109,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
|
||||
llvm::legacy::PassManager pass;
|
||||
// emit
|
||||
machine->addPassesToEmitFile(pass, pstream, nullptr,
|
||||
llvm::CodeGenFileType::CGFT_AssemblyFile);
|
||||
llvm::CodeGenFileType::AssemblyFile);
|
||||
pass.run(module);
|
||||
}
|
||||
// post-process
|
||||
|
||||
Reference in New Issue
Block a user