Merge commit 'cb3d79a185e40c9d8a579bea07747a8a8d157d52' into ifu-231117

Conflicts:
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
	lib/Dialect/TritonGPU/IR/Dialect.cpp
	python/setup.py
	python/test/unit/language/assert_helper.py
	python/test/unit/operators/test_flash_attention.py
	python/test/unit/runtime/test_subproc.py
	python/triton/compiler/compiler.py
	python/triton/language/semantic.py
	python/triton/runtime/autotuner.py
	python/triton/runtime/jit.py
	python/tutorials/03-matrix-multiplication.py
	python/tutorials/05-layer-norm.py
	python/tutorials/06-fused-attention.py
	python/tutorials/11-grouped-gemm.py
	test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
Jason Furmanek
2023-11-17 20:42:12 +00:00
179 changed files with 10116 additions and 6835 deletions

View File

@@ -28,6 +28,8 @@ static CUtensorMapDataType getCUtensorMapDataType(Type ty) {
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
} else if (ty.isF32()) {
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
} else if (ty.getIntOrFloatBitWidth() == 8) {
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else {
llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op");
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
@@ -930,6 +932,11 @@ private:
return -1 -
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
}
if (!isa<BlockArgument>(v) &&
!isa<mlir::UnrealizedConversionCastOp, arith::ExtSIOp>(
v.getDefiningOp()))
llvm::report_fatal_error(
"Operand of `MakeTensorPtrOp` is not the function's argument");
if (v.getDefiningOp() &&
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
return getArgIdx(v.getDefiningOp()->getOperand(0));
@@ -1095,40 +1102,81 @@ struct AtomicCASOpConversion
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
: valueTy;
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType());
// vec = 1 for scalar
auto vec = getVectorSize(op.getPtr());
// tensor
if (TensorTy) {
auto valTy = op.getVal().getType().cast<RankedTensorType>();
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
}
Value mask = getMask(valueTy, rewriter, loc);
auto vecTy = vec_ty(valueElemTy, vec);
SmallVector<Value> resultVals(elemsPerThread);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
Value casPtr = ptrElements[0];
Value casCmp = cmpElements[0];
Value casVal = valElements[0];
for (size_t i = 0; i < elemsPerThread; i += vec) {
Value casVal = undef(vecTy);
for (int ii = 0; ii < vec; ++ii) {
Value iiVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), ii);
casVal = insert_element(vecTy, casVal, valElements[i + ii], iiVal);
}
PTXBuilder ptxBuilderAtomicCAS;
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r", /*init=*/true);
auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
std::string semStr;
llvm::raw_string_ostream os(semStr);
os << op.getSem();
atom.global().o(semStr).o("cas").o("b32");
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
createBarrier(rewriter, loc, numCTAs);
Value casPtr = ptrElements[i];
Value casCmp = cmpElements[i];
casVal = valElements[i];
PTXBuilder ptxBuilderAtomicCAS;
std::string tyId = valueElemNBits * vec == 64
? "l"
: (valueElemNBits * vec == 32 ? "r" : "h");
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=" + tyId, /*init=*/true);
auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, tyId);
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, tyId);
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
auto sTy = "b" + std::to_string(valueElemNBits);
std::string semStr;
llvm::raw_string_ostream os(semStr);
os << op.getSem();
auto scope = stringifyMemSyncScope(op.getScope()).str();
atom.global().o(semStr).o(scope).o("cas").o(sTy);
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
PTXBuilder ptxBuilderStore;
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r");
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
st.shared().o("b32");
st(dstOprStore, valOprStore).predicate(mask);
auto ASMReturnTy = void_ty(ctx);
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
createBarrier(rewriter, loc, numCTAs);
Value ret = load(atomPtr);
createBarrier(rewriter, loc, numCTAs);
rewriter.replaceOp(op, {ret});
if (TensorTy) {
auto retType = vec == 1 ? valueElemTy : vecTy;
auto ret = ptxBuilderAtomicCAS.launch(rewriter, loc, retType);
for (int ii = 0; ii < vec; ++ii) {
resultVals[i + ii] =
vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
}
} else {
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
createBarrier(rewriter, loc, numCTAs);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
// Only threads with mask = True store the result
PTXBuilder ptxBuilderStore;
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r");
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
st.shared().o(sTy);
st(dstOprStore, valOprStore).predicate(mask);
auto ASMReturnTy = void_ty(ctx);
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
createBarrier(rewriter, loc, numCTAs);
Value ret = load(atomPtr);
createBarrier(rewriter, loc, numCTAs);
rewriter.replaceOp(op, {ret});
}
}
if (TensorTy) {
Type structTy = getTypeConverter()->convertType(TensorTy);
Value resultStruct = getTypeConverter()->packLLElements(
loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, {resultStruct});
}
return success();
}
#endif // USE_ROCM
@@ -1360,7 +1408,8 @@ struct AtomicRMWOpConversion
auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o("gpu");
auto scope = stringifyMemSyncScope(op.getScope()).str();
auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o(scope);
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
auto sBits = std::to_string(valueElemNBits);
switch (atomicRmwAttr) {
@@ -2001,6 +2050,11 @@ private:
return -1 -
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
}
if (!isa<BlockArgument>(v) &&
!isa<mlir::UnrealizedConversionCastOp, arith::ExtSIOp>(
v.getDefiningOp()))
llvm::report_fatal_error(
"Operand of `MakeTensorPtrOp` is not the function's argument");
if (v.getDefiningOp() &&
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
return getArgIdx(v.getDefiningOp()->getOperand(0));