mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit 'ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33' into ifu-rebase-again
Conflicts: .gitignore .gitmodules README.md bin/triton-translate.cpp include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td include/triton/Target/AMDGCN/AMDGCNTranslation.h include/triton/Target/HSACO/HSACOTranslation.h lib/Analysis/Allocation.cpp lib/Analysis/Utility.cpp lib/Conversion/TritonGPUToLLVM/CMakeLists.txt lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/Utility.cpp lib/Conversion/TritonGPUToLLVM/Utility.h lib/Dialect/TritonGPU/IR/Dialect.cpp lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp lib/Target/HSACO/CMakeLists.txt lib/Target/HSACO/HSACOTranslation.cpp lib/Target/LLVMIR/LLVMIRTranslation.cpp python/src/triton.cc python/test/unit/language/test_core.py python/test/unit/operators/test_flash_attention.py python/triton/compiler/compiler.py python/triton/compiler/make_launcher.py python/triton/language/semantic.py python/triton/runtime/jit.py python/tutorials/06-fused-attention.py python/tutorials/11-grouped-gemm.py test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -21,6 +21,19 @@ using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
// Translates an MLIR element type into the matching CUtensorMapDataType
// enumerator. f16, bf16 and f32 are handled; every other type aborts with a
// fatal error.
//
// As a file-level static this helper can be reached from more than one
// conversion pattern, so the diagnostic names the tensor-map users instead of
// singling out one op.
static CUtensorMapDataType getCUtensorMapDataType(Type ty) {
  if (ty.isF16())
    return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
  if (ty.isBF16())
    return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
  if (ty.isF32())
    return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;

  // report_fatal_error never returns, so no placeholder return value follows.
  llvm::report_fatal_error(
      "Unsupported elemTy for CUtensorMap (StoreAsyncOp / "
      "InsertSliceAsyncV2Op)");
}
|
||||
|
||||
// Contains some helper functions for both Load and Store conversions.
|
||||
struct LoadStoreConversionBase {
|
||||
explicit LoadStoreConversionBase(ModuleAxisInfoAnalysis &axisAnalysisPass)
|
||||
@@ -804,7 +817,7 @@ struct StoreAsyncOpConversion
|
||||
typeConverter->convertType(rewriter.getI8Type()), 3);
|
||||
|
||||
auto threadId = getThreadId(rewriter, loc);
|
||||
Value pred = icmp_eq(urem(threadId, i32_val(32)), i32_val(0));
|
||||
Value pred = int_val(1, 1);
|
||||
|
||||
auto llCoord = getTypeConverter()->unpackLLElements(loc, llDst, rewriter,
|
||||
dst.getType());
|
||||
@@ -912,17 +925,6 @@ struct StoreAsyncOpConversion
|
||||
}
|
||||
|
||||
private:
|
||||
CUtensorMapDataType getCUtensorMapDataType(Type ty) const {
|
||||
if (ty.isF16()) {
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
} else if (ty.isF32()) {
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
} else {
|
||||
llvm::report_fatal_error("Unsupported elemTy for StoreAsyncOp");
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
}
|
||||
}
|
||||
|
||||
unsigned getArgIdx(Value v) const {
|
||||
if (auto op = v.getDefiningOp<mlir::arith::ConstantOp>()) {
|
||||
return -1 -
|
||||
@@ -969,6 +971,18 @@ private:
|
||||
const TensorPtrMapT *tensorPtrMap;
|
||||
};
|
||||
|
||||
namespace {
|
||||
void createBarrier(ConversionPatternRewriter &rewriter, Location loc,
|
||||
int numCTAs) {
|
||||
if (numCTAs == 1) {
|
||||
barrier();
|
||||
} else {
|
||||
rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
|
||||
rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
struct AtomicCASOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>,
|
||||
public LoadStoreConversionBase {
|
||||
@@ -1060,6 +1074,10 @@ struct AtomicCASOpConversion
|
||||
auto loc = op.getLoc();
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
assert(moduleOp && "Parent ModuleOp not found for AtomicCASOp");
|
||||
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
|
||||
|
||||
Value llPtr = adaptor.getPtr();
|
||||
Value llCmp = adaptor.getCmp();
|
||||
Value llVal = adaptor.getVal();
|
||||
@@ -1097,7 +1115,7 @@ struct AtomicCASOpConversion
|
||||
atom.global().o(semStr).o("cas").o("b32");
|
||||
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
|
||||
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
|
||||
barrier();
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
|
||||
PTXBuilder ptxBuilderStore;
|
||||
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r");
|
||||
@@ -1107,9 +1125,9 @@ struct AtomicCASOpConversion
|
||||
st(dstOprStore, valOprStore).predicate(mask);
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
|
||||
barrier();
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
Value ret = load(atomPtr);
|
||||
barrier();
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
rewriter.replaceOp(op, {ret});
|
||||
return success();
|
||||
}
|
||||
@@ -1279,7 +1297,11 @@ struct AtomicRMWOpConversion
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
//
|
||||
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
assert(moduleOp && "Parent ModuleOp not found for AtomicRMWOp");
|
||||
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
|
||||
|
||||
auto atomicRmwAttr = op.getAtomicRmwOp();
|
||||
|
||||
Value val = op.getVal();
|
||||
@@ -1352,7 +1374,7 @@ struct AtomicRMWOpConversion
|
||||
sTy = "b" + sBits;
|
||||
break;
|
||||
case RMWOp::ADD:
|
||||
sTy = "s" + sBits;
|
||||
sTy = "u" + sBits;
|
||||
break;
|
||||
case RMWOp::FADD:
|
||||
rmwOp = "add";
|
||||
@@ -1410,9 +1432,9 @@ struct AtomicRMWOpConversion
|
||||
auto *valOpr = ptxBuilderStore.newOperand(old, tyId);
|
||||
storeShared(ptrOpr, valOpr).predicate(rmwMask);
|
||||
ptxBuilderStore.launch(rewriter, loc, void_ty(ctx));
|
||||
barrier();
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
Value ret = load(atomPtr);
|
||||
barrier();
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
rewriter.replaceOp(op, {ret});
|
||||
}
|
||||
}
|
||||
@@ -1980,17 +2002,6 @@ private:
|
||||
return bcastMask;
|
||||
}
|
||||
|
||||
CUtensorMapDataType getCUtensorMapDataType(Type ty) const {
|
||||
if (ty.isF16()) {
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
} else if (ty.isF32()) {
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
} else {
|
||||
llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op");
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
}
|
||||
}
|
||||
|
||||
unsigned getArgIdx(Value v) const {
|
||||
if (auto op = v.getDefiningOp<mlir::arith::ConstantOp>()) {
|
||||
return -1 -
|
||||
|
||||
Reference in New Issue
Block a user