Merge commit 'ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33' into ifu-rebase-again

Conflicts:
	.gitignore
	.gitmodules
	README.md
	bin/triton-translate.cpp
	include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
	include/triton/Target/AMDGCN/AMDGCNTranslation.h
	include/triton/Target/HSACO/HSACOTranslation.h
	lib/Analysis/Allocation.cpp
	lib/Analysis/Utility.cpp
	lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
	lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/Utility.cpp
	lib/Conversion/TritonGPUToLLVM/Utility.h
	lib/Dialect/TritonGPU/IR/Dialect.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	lib/Target/HSACO/CMakeLists.txt
	lib/Target/HSACO/HSACOTranslation.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/src/triton.cc
	python/test/unit/language/test_core.py
	python/test/unit/operators/test_flash_attention.py
	python/triton/compiler/compiler.py
	python/triton/compiler/make_launcher.py
	python/triton/language/semantic.py
	python/triton/runtime/jit.py
	python/tutorials/06-fused-attention.py
	python/tutorials/11-grouped-gemm.py
	test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
Jason Furmanek
2023-11-06 23:10:10 +00:00
161 changed files with 6530 additions and 3905 deletions

View File

@@ -21,6 +21,19 @@ using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
// Map an MLIR element type onto the CUDA tensor-map (TMA) data-type enum.
// Only f16 / bf16 / f32 elements are representable here; any other element
// type aborts compilation with a fatal error.
static CUtensorMapDataType getCUtensorMapDataType(Type elemTy) {
  if (elemTy.isF16())
    return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
  if (elemTy.isBF16())
    return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
  if (elemTy.isF32())
    return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
  llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op");
  // Not reachable (report_fatal_error does not return); placates compilers
  // that still demand a return value on this path.
  return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
}
// Contains some helper functions for both Load and Store conversions.
struct LoadStoreConversionBase {
explicit LoadStoreConversionBase(ModuleAxisInfoAnalysis &axisAnalysisPass)
@@ -804,7 +817,7 @@ struct StoreAsyncOpConversion
typeConverter->convertType(rewriter.getI8Type()), 3);
auto threadId = getThreadId(rewriter, loc);
Value pred = icmp_eq(urem(threadId, i32_val(32)), i32_val(0));
Value pred = int_val(1, 1);
auto llCoord = getTypeConverter()->unpackLLElements(loc, llDst, rewriter,
dst.getType());
@@ -912,17 +925,6 @@ struct StoreAsyncOpConversion
}
private:
// Map an MLIR element type onto the CUDA tensor-map (TMA) data-type enum
// used when encoding the descriptor for StoreAsyncOp.
// Fatal error for any element type other than f16 / bf16 / f32.
CUtensorMapDataType getCUtensorMapDataType(Type ty) const {
  if (ty.isF16()) {
    return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
  } else if (ty.isBF16()) {
    // bf16 is a valid TMA element type (CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
    // it was previously missing here, so bf16 async stores hit the fatal
    // error below even though the equivalent file-level helper accepts bf16.
    return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
  } else if (ty.isF32()) {
    return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
  } else {
    llvm::report_fatal_error("Unsupported elemTy for StoreAsyncOp");
    // Unreachable; satisfies the return-value requirement on this path.
    return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
  }
}
unsigned getArgIdx(Value v) const {
if (auto op = v.getDefiningOp<mlir::arith::ConstantOp>()) {
return -1 -
@@ -969,6 +971,18 @@ private:
const TensorPtrMapT *tensorPtrMap;
};
namespace {
// Emit the synchronization appropriate for the launch configuration:
// a plain CTA-level barrier when the kernel uses a single CTA, or a
// cluster arrive/wait pair when it spans multiple CTAs.
void createBarrier(ConversionPatternRewriter &rewriter, Location loc,
                   int numCTAs) {
  if (numCTAs != 1) {
    // Multi-CTA launch: synchronize across the whole cluster.
    rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
    rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
    return;
  }
  barrier();
}
} // namespace
struct AtomicCASOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>,
public LoadStoreConversionBase {
@@ -1060,6 +1074,10 @@ struct AtomicCASOpConversion
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
auto moduleOp = op->getParentOfType<ModuleOp>();
assert(moduleOp && "Parent ModuleOp not found for AtomicCASOp");
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
Value llPtr = adaptor.getPtr();
Value llCmp = adaptor.getCmp();
Value llVal = adaptor.getVal();
@@ -1097,7 +1115,7 @@ struct AtomicCASOpConversion
atom.global().o(semStr).o("cas").o("b32");
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
barrier();
createBarrier(rewriter, loc, numCTAs);
PTXBuilder ptxBuilderStore;
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r");
@@ -1107,9 +1125,9 @@ struct AtomicCASOpConversion
st(dstOprStore, valOprStore).predicate(mask);
auto ASMReturnTy = void_ty(ctx);
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
barrier();
createBarrier(rewriter, loc, numCTAs);
Value ret = load(atomPtr);
barrier();
createBarrier(rewriter, loc, numCTAs);
rewriter.replaceOp(op, {ret});
return success();
}
@@ -1279,7 +1297,11 @@ struct AtomicRMWOpConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
//
auto moduleOp = op->getParentOfType<ModuleOp>();
assert(moduleOp && "Parent ModuleOp not found for AtomicRMWOp");
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
auto atomicRmwAttr = op.getAtomicRmwOp();
Value val = op.getVal();
@@ -1352,7 +1374,7 @@ struct AtomicRMWOpConversion
sTy = "b" + sBits;
break;
case RMWOp::ADD:
sTy = "s" + sBits;
sTy = "u" + sBits;
break;
case RMWOp::FADD:
rmwOp = "add";
@@ -1410,9 +1432,9 @@ struct AtomicRMWOpConversion
auto *valOpr = ptxBuilderStore.newOperand(old, tyId);
storeShared(ptrOpr, valOpr).predicate(rmwMask);
ptxBuilderStore.launch(rewriter, loc, void_ty(ctx));
barrier();
createBarrier(rewriter, loc, numCTAs);
Value ret = load(atomPtr);
barrier();
createBarrier(rewriter, loc, numCTAs);
rewriter.replaceOp(op, {ret});
}
}
@@ -1980,17 +2002,6 @@ private:
return bcastMask;
}
// Map an MLIR element type onto the CUDA tensor-map (TMA) data-type enum
// used when encoding the descriptor for InsertSliceAsyncV2Op.
// Fatal error for any element type other than f16 / bf16 / f32.
CUtensorMapDataType getCUtensorMapDataType(Type ty) const {
  if (ty.isF16()) {
    return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
  } else if (ty.isBF16()) {
    // bf16 is a valid TMA element type (CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
    // the file-level helper that emits the identical error message already
    // handles bf16, so accept it here as well instead of aborting.
    return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
  } else if (ty.isF32()) {
    return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
  } else {
    llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op");
    // Unreachable; satisfies the return-value requirement on this path.
    return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
  }
}
unsigned getArgIdx(Value v) const {
if (auto op = v.getDefiningOp<mlir::arith::ConstantOp>()) {
return -1 -