mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit 'cb3d79a185e40c9d8a579bea07747a8a8d157d52' into ifu-231117
Conflicts: lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp lib/Dialect/TritonGPU/IR/Dialect.cpp python/setup.py python/test/unit/language/assert_helper.py python/test/unit/operators/test_flash_attention.py python/test/unit/runtime/test_subproc.py python/triton/compiler/compiler.py python/triton/language/semantic.py python/triton/runtime/autotuner.py python/triton/runtime/jit.py python/tutorials/03-matrix-multiplication.py python/tutorials/05-layer-norm.py python/tutorials/06-fused-attention.py python/tutorials/11-grouped-gemm.py test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -28,6 +28,8 @@ static CUtensorMapDataType getCUtensorMapDataType(Type ty) {
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
|
||||
} else if (ty.isF32()) {
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
} else if (ty.getIntOrFloatBitWidth() == 8) {
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else {
|
||||
llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op");
|
||||
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
@@ -930,6 +932,11 @@ private:
|
||||
return -1 -
|
||||
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
if (!isa<BlockArgument>(v) &&
|
||||
!isa<mlir::UnrealizedConversionCastOp, arith::ExtSIOp>(
|
||||
v.getDefiningOp()))
|
||||
llvm::report_fatal_error(
|
||||
"Operand of `MakeTensorPtrOp` is not the function's argument");
|
||||
if (v.getDefiningOp() &&
|
||||
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
|
||||
return getArgIdx(v.getDefiningOp()->getOperand(0));
|
||||
@@ -1095,40 +1102,81 @@ struct AtomicCASOpConversion
|
||||
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
|
||||
: valueTy;
|
||||
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType());
|
||||
// vec = 1 for scalar
|
||||
auto vec = getVectorSize(op.getPtr());
|
||||
// tensor
|
||||
if (TensorTy) {
|
||||
auto valTy = op.getVal().getType().cast<RankedTensorType>();
|
||||
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
|
||||
}
|
||||
|
||||
Value mask = getMask(valueTy, rewriter, loc);
|
||||
auto vecTy = vec_ty(valueElemTy, vec);
|
||||
SmallVector<Value> resultVals(elemsPerThread);
|
||||
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
Value casPtr = ptrElements[0];
|
||||
Value casCmp = cmpElements[0];
|
||||
Value casVal = valElements[0];
|
||||
for (size_t i = 0; i < elemsPerThread; i += vec) {
|
||||
Value casVal = undef(vecTy);
|
||||
for (int ii = 0; ii < vec; ++ii) {
|
||||
Value iiVal = createIndexAttrConstant(
|
||||
rewriter, loc, getTypeConverter()->getIndexType(), ii);
|
||||
casVal = insert_element(vecTy, casVal, valElements[i + ii], iiVal);
|
||||
}
|
||||
|
||||
PTXBuilder ptxBuilderAtomicCAS;
|
||||
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r", /*init=*/true);
|
||||
auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
|
||||
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
|
||||
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
|
||||
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
|
||||
std::string semStr;
|
||||
llvm::raw_string_ostream os(semStr);
|
||||
os << op.getSem();
|
||||
atom.global().o(semStr).o("cas").o("b32");
|
||||
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
|
||||
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
Value casPtr = ptrElements[i];
|
||||
Value casCmp = cmpElements[i];
|
||||
casVal = valElements[i];
|
||||
PTXBuilder ptxBuilderAtomicCAS;
|
||||
std::string tyId = valueElemNBits * vec == 64
|
||||
? "l"
|
||||
: (valueElemNBits * vec == 32 ? "r" : "h");
|
||||
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=" + tyId, /*init=*/true);
|
||||
auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
|
||||
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, tyId);
|
||||
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, tyId);
|
||||
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
|
||||
auto sTy = "b" + std::to_string(valueElemNBits);
|
||||
std::string semStr;
|
||||
llvm::raw_string_ostream os(semStr);
|
||||
os << op.getSem();
|
||||
auto scope = stringifyMemSyncScope(op.getScope()).str();
|
||||
atom.global().o(semStr).o(scope).o("cas").o(sTy);
|
||||
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
|
||||
|
||||
PTXBuilder ptxBuilderStore;
|
||||
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r");
|
||||
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
|
||||
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
|
||||
st.shared().o("b32");
|
||||
st(dstOprStore, valOprStore).predicate(mask);
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
Value ret = load(atomPtr);
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
rewriter.replaceOp(op, {ret});
|
||||
if (TensorTy) {
|
||||
auto retType = vec == 1 ? valueElemTy : vecTy;
|
||||
auto ret = ptxBuilderAtomicCAS.launch(rewriter, loc, retType);
|
||||
for (int ii = 0; ii < vec; ++ii) {
|
||||
resultVals[i + ii] =
|
||||
vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
|
||||
}
|
||||
} else {
|
||||
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
// Only threads with mask = True store the result
|
||||
PTXBuilder ptxBuilderStore;
|
||||
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r");
|
||||
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
|
||||
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
|
||||
st.shared().o(sTy);
|
||||
st(dstOprStore, valOprStore).predicate(mask);
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
Value ret = load(atomPtr);
|
||||
createBarrier(rewriter, loc, numCTAs);
|
||||
rewriter.replaceOp(op, {ret});
|
||||
}
|
||||
}
|
||||
|
||||
if (TensorTy) {
|
||||
Type structTy = getTypeConverter()->convertType(TensorTy);
|
||||
Value resultStruct = getTypeConverter()->packLLElements(
|
||||
loc, resultVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, {resultStruct});
|
||||
}
|
||||
return success();
|
||||
}
|
||||
#endif // USE_ROCM
|
||||
@@ -1360,7 +1408,8 @@ struct AtomicRMWOpConversion
|
||||
auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
|
||||
auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
|
||||
|
||||
auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o("gpu");
|
||||
auto scope = stringifyMemSyncScope(op.getScope()).str();
|
||||
auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o(scope);
|
||||
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
|
||||
auto sBits = std::to_string(valueElemNBits);
|
||||
switch (atomicRmwAttr) {
|
||||
@@ -2001,6 +2050,11 @@ private:
|
||||
return -1 -
|
||||
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
if (!isa<BlockArgument>(v) &&
|
||||
!isa<mlir::UnrealizedConversionCastOp, arith::ExtSIOp>(
|
||||
v.getDefiningOp()))
|
||||
llvm::report_fatal_error(
|
||||
"Operand of `MakeTensorPtrOp` is not the function's argument");
|
||||
if (v.getDefiningOp() &&
|
||||
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
|
||||
return getArgIdx(v.getDefiningOp()->getOperand(0));
|
||||
|
||||
Reference in New Issue
Block a user