Merge remote-tracking branch 'oai/main' into ifu230620

Conflicts:
	include/triton/Conversion/TritonToTritonGPU/Passes.td
	include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	python/test/unit/language/assert_helper.py
	python/triton/compiler/compiler.py
	python/triton/runtime/jit.py
	python/triton/tools/aot.py
	test/Conversion/triton_to_tritongpu.mlir
	test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
Jason Furmanek
2023-06-29 21:47:27 +00:00
80 changed files with 2508 additions and 959 deletions

View File

@@ -402,7 +402,14 @@ struct StoreOpConversion
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
auto &ptxStoreInstr =
ptxBuilder.create<>("st")->global().v(nWords).b(width);
ptxBuilder.create<>("st")
->global()
.o("L1::evict_first",
op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
.o("L1::evict_last",
op.getEvict() == triton::EvictionPolicy::EVICT_LAST)
.v(nWords)
.b(width);
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
@@ -528,11 +535,6 @@ struct AtomicCASOpConversion
: valueTy;
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
Value mask = getMask(valueTy, rewriter, loc);
PTXBuilder ptxBuilderMemfence;
auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
memfence();
auto ASMReturnTy = void_ty(ctx);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
@@ -546,7 +548,10 @@ struct AtomicCASOpConversion
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
atom.global().o("cas").o("b32");
std::string semStr;
llvm::raw_string_ostream os(semStr);
os << op.getSem();
atom.global().o(semStr).o("cas").o("b32");
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
barrier();
@@ -557,8 +562,8 @@ struct AtomicCASOpConversion
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
st.shared().o("b32");
st(dstOprStore, valOprStore).predicate(mask);
auto ASMReturnTy = void_ty(ctx);
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
barrier();
Value ret = load(atomPtr);
barrier();
@@ -731,7 +736,7 @@ struct AtomicRMWOpConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
//
auto atomicRmwAttr = op.getAtomicRmwOp();
Value val = op.getVal();
@@ -832,7 +837,10 @@ struct AtomicRMWOpConversion
default:
return failure();
}
atom.o(rmwOp).o(sTy);
std::string semStr;
llvm::raw_string_ostream os(semStr);
os << op.getSem();
atom.o(semStr).o(rmwOp).o(sTy);
if (tensorTy) {
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto retType = vec == 1 ? valueElemTy : vecTy;
@@ -842,13 +850,13 @@ struct AtomicRMWOpConversion
vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
}
} else {
PTXBuilder ptxBuilderMemfence;
auto memfenc = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
memfenc();
auto ASMReturnTy = void_ty(ctx);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
if (op->user_begin() == op->user_end()) {
rewriter.replaceOp(op, {old});
return success();
}
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
// Only threads with rmwMask = True store the result