mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge remote-tracking branch 'oai/main' into ifu230620
Conflicts: include/triton/Conversion/TritonToTritonGPU/Passes.td include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp python/test/unit/language/assert_helper.py python/triton/compiler/compiler.py python/triton/runtime/jit.py python/triton/tools/aot.py test/Conversion/triton_to_tritongpu.mlir test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -402,7 +402,14 @@ struct StoreOpConversion
|
||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||
|
||||
auto &ptxStoreInstr =
|
||||
ptxBuilder.create<>("st")->global().v(nWords).b(width);
|
||||
ptxBuilder.create<>("st")
|
||||
->global()
|
||||
.o("L1::evict_first",
|
||||
op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
|
||||
.o("L1::evict_last",
|
||||
op.getEvict() == triton::EvictionPolicy::EVICT_LAST)
|
||||
.v(nWords)
|
||||
.b(width);
|
||||
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
|
||||
|
||||
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
|
||||
@@ -528,11 +535,6 @@ struct AtomicCASOpConversion
|
||||
: valueTy;
|
||||
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
Value mask = getMask(valueTy, rewriter, loc);
|
||||
PTXBuilder ptxBuilderMemfence;
|
||||
auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
|
||||
memfence();
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
|
||||
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
@@ -546,7 +548,10 @@ struct AtomicCASOpConversion
|
||||
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
|
||||
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
|
||||
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
|
||||
atom.global().o("cas").o("b32");
|
||||
std::string semStr;
|
||||
llvm::raw_string_ostream os(semStr);
|
||||
os << op.getSem();
|
||||
atom.global().o(semStr).o("cas").o("b32");
|
||||
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
|
||||
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
|
||||
barrier();
|
||||
@@ -557,8 +562,8 @@ struct AtomicCASOpConversion
|
||||
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
|
||||
st.shared().o("b32");
|
||||
st(dstOprStore, valOprStore).predicate(mask);
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
|
||||
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
|
||||
barrier();
|
||||
Value ret = load(atomPtr);
|
||||
barrier();
|
||||
@@ -731,7 +736,7 @@ struct AtomicRMWOpConversion
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
|
||||
//
|
||||
auto atomicRmwAttr = op.getAtomicRmwOp();
|
||||
|
||||
Value val = op.getVal();
|
||||
@@ -832,7 +837,10 @@ struct AtomicRMWOpConversion
|
||||
default:
|
||||
return failure();
|
||||
}
|
||||
atom.o(rmwOp).o(sTy);
|
||||
std::string semStr;
|
||||
llvm::raw_string_ostream os(semStr);
|
||||
os << op.getSem();
|
||||
atom.o(semStr).o(rmwOp).o(sTy);
|
||||
if (tensorTy) {
|
||||
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
||||
auto retType = vec == 1 ? valueElemTy : vecTy;
|
||||
@@ -842,13 +850,13 @@ struct AtomicRMWOpConversion
|
||||
vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
|
||||
}
|
||||
} else {
|
||||
PTXBuilder ptxBuilderMemfence;
|
||||
auto memfenc = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
|
||||
memfenc();
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
|
||||
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
||||
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
|
||||
if (op->user_begin() == op->user_end()) {
|
||||
rewriter.replaceOp(op, {old});
|
||||
return success();
|
||||
}
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
// Only threads with rmwMask = True store the result
|
||||
|
||||
Reference in New Issue
Block a user