Merge remote-tracking branch 'openai/main' into IFU-230517

Conflicts:
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/test/unit/language/assert_helper.py
	python/triton/third_party/cuda/bin/ptxas
	test/Conversion/tritongpu_to_llvm.mlir

Author: Jason Furmanek
Date:   2023-05-17 15:03:42 +00:00

99 changed files with 4561 additions and 1251 deletions


@@ -13,7 +13,7 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
 // Contains some helper functions for both Load and Store conversions.
 struct LoadStoreConversionBase {
-  explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
+  explicit LoadStoreConversionBase(ModuleAxisInfoAnalysis &axisAnalysisPass)
       : axisAnalysisPass(axisAnalysisPass) {}

   unsigned getContiguity(Value ptr) const {
@@ -38,7 +38,7 @@ struct LoadStoreConversionBase {
   }

 protected:
-  AxisInfoAnalysis &axisAnalysisPass;
+  ModuleAxisInfoAnalysis &axisAnalysisPass;
 };

 struct LoadOpConversion
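The change is mechanical throughout this file: the per-function AxisInfoAnalysis held by LoadStoreConversionBase becomes a module-level ModuleAxisInfoAnalysis. For orientation, here is a sketch of a plausible getContiguity body (the first hunk cuts off inside it); getPtrContiguity is an assumed query on the module-level analysis, not something shown in this diff:

    // Sketch only: plausible completion of getContiguity under the
    // module-level analysis. getPtrContiguity is assumed, not shown here.
    unsigned getContiguity(Value ptr) const {
      auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
      if (!tensorTy)
        return 1; // scalar pointers: nothing to vectorize
      return axisAnalysisPass.getPtrContiguity(ptr);
    }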
@@ -48,7 +48,8 @@ struct LoadOpConversion
       triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;

   LoadOpConversion(TritonGPUToLLVMTypeConverter &converter,
-                   AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
+                   ModuleAxisInfoAnalysis &axisAnalysisPass,
+                   PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
         LoadStoreConversionBase(axisAnalysisPass) {}
@@ -293,7 +294,8 @@ struct StoreOpConversion
       triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;

   StoreOpConversion(TritonGPUToLLVMTypeConverter &converter,
-                    AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
+                    ModuleAxisInfoAnalysis &axisAnalysisPass,
+                    PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
         LoadStoreConversionBase(axisAnalysisPass) {}
@@ -335,14 +337,7 @@ struct StoreOpConversion
       vec = std::min(vec, maskAlign);
     }

-    // numElements = 1 for scalar
-    auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
-    auto numElems = tensorTy ? tensorTy.getNumElements() : 1;
-    Value mask = int_val(1, 1);
-    auto tid = tid_val();
-    mask = and_(mask,
-                icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
+    Value mask = getMask(valueTy, rewriter, loc);

     const size_t dtsize =
         std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
     const size_t valueElemNBits = dtsize * 8;
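The inline thread-masking logic deleted here (and again in AtomicRMWOpConversion below) is folded into a shared getMask helper. A minimal sketch of what such a helper plausibly computes, reconstructed from the deleted tensor-case lines and from the scalar pred = icmp_eq(tid, i32_val(0)) logic the CAS hunks replace; the actual implementation is not shown in this diff:

    // Sketch, not the actual implementation: merges the two masking
    // patterns this diff deletes into one helper.
    Value getMask(Type valueTy, ConversionPatternRewriter &rewriter,
                  Location loc) const {
      auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
      Value mask = int_val(1, 1); // start from "true"
      auto tid = tid_val();
      if (tensorTy) {
        // Tensor case: a thread participates only if its first element
        // index falls inside the tensor.
        auto numElems = tensorTy.getNumElements();
        auto elemsPerThread = getTotalElemsPerThread(tensorTy);
        mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
                                   i32_val(numElems)));
      } else {
        // Scalar case: only thread 0 performs the access.
        mask = and_(mask, icmp_eq(tid, i32_val(0)));
      }
      return mask;
    }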
@@ -431,11 +426,11 @@ struct AtomicCASOpConversion
       triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;

   AtomicCASOpConversion(TritonGPUToLLVMTypeConverter &converter,
-                        const Allocation *allocation, Value smem,
-                        AxisInfoAnalysis &axisAnalysisPass,
+                        ModuleAllocation &allocation,
+                        ModuleAxisInfoAnalysis &axisAnalysisPass,
                         PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(
-            converter, allocation, smem, benefit),
+            converter, allocation, benefit),
         LoadStoreConversionBase(axisAnalysisPass) {}

 #ifdef USE_ROCM
@@ -526,13 +521,13 @@ struct AtomicCASOpConversion
     auto valElements = getTypeConverter()->unpackLLElements(
         loc, llVal, rewriter, op.getVal().getType());

-    auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    auto valueTy = op.getResult().getType();
+    auto TensorTy = valueTy.dyn_cast<RankedTensorType>();
     Type valueElemTy =
         TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
-                 : op.getResult().getType();
+                 : valueTy;
     auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
-    auto tid = tid_val();
-    Value pred = icmp_eq(tid, i32_val(0));
+    Value mask = getMask(valueTy, rewriter, loc);

     PTXBuilder ptxBuilderMemfence;
     auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
     memfence();
@@ -552,7 +547,7 @@ struct AtomicCASOpConversion
       auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
       auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
       atom.global().o("cas").o("b32");
-      atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred);
+      atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
       auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
       barrier();
@@ -561,7 +556,7 @@ struct AtomicCASOpConversion
       auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
       auto &st = *ptxBuilderStore.create<PTXInstr>("st");
       st.shared().o("b32");
-      st(dstOprStore, valOprStore).predicate(pred);
+      st(dstOprStore, valOprStore).predicate(mask);
       ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
       ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
       barrier();
@@ -580,11 +575,11 @@ struct AtomicRMWOpConversion
       triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;

   AtomicRMWOpConversion(TritonGPUToLLVMTypeConverter &converter,
-                        const Allocation *allocation, Value smem,
-                        AxisInfoAnalysis &axisAnalysisPass,
+                        ModuleAllocation &allocation,
+                        ModuleAxisInfoAnalysis &axisAnalysisPass,
                         PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(
-            converter, allocation, smem, benefit),
+            converter, allocation, benefit),
         LoadStoreConversionBase(axisAnalysisPass) {}

 #ifdef USE_ROCM
@@ -747,10 +742,11 @@ struct AtomicRMWOpConversion
       maskElements = getTypeConverter()->unpackLLElements(
           loc, llMask, rewriter, op.getMask().getType());

-    auto tensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    auto valueTy = op.getResult().getType();
+    auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
     Type valueElemTy =
         tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
-                 : op.getResult().getType();
+                 : valueTy;
     const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
     auto elemsPerThread = getTotalElemsPerThread(val.getType());
     // vec = 1, numElements = 1 for scalar
@@ -763,10 +759,7 @@ struct AtomicRMWOpConversion
       // mask
       numElems = tensorTy.getNumElements();
     }
-    Value mask = int_val(1, 1);
-    auto tid = tid_val();
-    mask = and_(mask,
-                icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
+    Value mask = getMask(valueTy, rewriter, loc);

     auto vecTy = vec_ty(valueElemTy, vec);
     SmallVector<Value> resultVals(elemsPerThread);
@@ -846,7 +839,6 @@ struct AtomicRMWOpConversion
         memfenc();
         auto ASMReturnTy = void_ty(ctx);
         ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
-        rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
         atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
         auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
         Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
@@ -889,7 +881,9 @@ struct InsertSliceOpConversion
     Value dst = op.getDest();
     Value src = op.getSource();
     Value res = op.getResult();
-    assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
+    auto funcOp = op->getParentOfType<FunctionOpInterface>();
+    auto *funcAllocation = allocation->getFuncData(funcOp);
+    assert(funcAllocation->getBufferId(res) == Allocation::InvalidBufferId &&
            "Only support in-place insert_slice for now");
     auto srcTy = src.getType().dyn_cast<RankedTensorType>();
@@ -949,12 +943,11 @@ struct InsertSliceAsyncOpConversion
       triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;

   InsertSliceAsyncOpConversion(
-      TritonGPUToLLVMTypeConverter &converter, const Allocation *allocation,
-      Value smem,
+      TritonGPUToLLVMTypeConverter &converter, ModuleAllocation &allocation,
       ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
-      AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
+      ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
-            converter, allocation, smem, indexCacheInfo, benefit),
+            converter, allocation, indexCacheInfo, benefit),
         LoadStoreConversionBase(axisAnalysisPass) {}

   LogicalResult
@@ -967,7 +960,9 @@ struct InsertSliceAsyncOpConversion
     Value res = op.getResult();
     Value mask = op.getMask();
     Value other = op.getOther();
-    assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
+    auto funcOp = op->getParentOfType<FunctionOpInterface>();
+    auto *funcAllocation = allocation->getFuncData(funcOp);
+    assert(funcAllocation->getBufferId(res) == Allocation::InvalidBufferId &&
            "Only support in-place insert_slice_async for now");
     auto srcTy = src.getType().cast<RankedTensorType>();
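Both insert_slice conversions now resolve their buffers through the enclosing function: ModuleAllocation hands out a per-function Allocation via getFuncData. Only that one method appears in this diff; one plausible shape for the container, as a hedged sketch, is a map keyed by the function op:

    // Sketch (assumed layout, not the actual implementation): a module-level
    // wrapper owning one Allocation per function, so patterns first look up
    // the enclosing function, then query buffer ids as before.
    class ModuleAllocation {
    public:
      Allocation *getFuncData(FunctionOpInterface funcOp) {
        auto it = funcAllocations.find(funcOp.getOperation());
        return it == funcAllocations.end() ? nullptr : it->second.get();
      }

    private:
      // Keyed by the function operation; filled when the module is analyzed.
      llvm::DenseMap<mlir::Operation *, std::unique_ptr<Allocation>>
          funcAllocations;
    };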
@@ -1107,19 +1102,17 @@ struct InsertSliceAsyncOpConversion

 void populateLoadStoreOpToLLVMPatterns(
     TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
-    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
-    const Allocation *allocation, Value smem,
+    ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
     ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
     PatternBenefit benefit) {
   patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
   patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
-  patterns.add<AtomicCASOpConversion>(typeConverter, allocation, smem,
+  patterns.add<AtomicCASOpConversion>(typeConverter, allocation,
                                       axisInfoAnalysis, benefit);
-  patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
+  patterns.add<AtomicRMWOpConversion>(typeConverter, allocation,
                                       axisInfoAnalysis, benefit);
-  patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
+  patterns.add<InsertSliceOpConversion>(typeConverter, allocation,
                                         indexCacheInfo, benefit);
-  patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
-                                             indexCacheInfo, axisInfoAnalysis,
-                                             benefit);
+  patterns.add<InsertSliceAsyncOpConversion>(
+      typeConverter, allocation, indexCacheInfo, axisInfoAnalysis, benefit);
 }
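For orientation, a hypothetical call site matching the signature as shown in the hunk above. The variable names, the analysis constructors, and the benefit value are assumptions, not part of this diff; typeConverter is assumed to be built earlier in the pass:

    // Hypothetical caller, e.g. inside the conversion pass setup (a sketch).
    mlir::ModuleOp mod = getOperation();
    ModuleAxisInfoAnalysis axisInfoAnalysis(mod); // constructor args assumed
    ModuleAllocation allocation(mod);             // constructor args assumed
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo;
    mlir::RewritePatternSet patterns(mod.getContext());
    populateLoadStoreOpToLLVMPatterns(typeConverter, patterns,
                                      axisInfoAnalysis, allocation,
                                      indexCacheInfo, /*benefit=*/10);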