Mirror of https://github.com/ROCm/ROCm.git, synced 2026-04-05 03:01:17 -04:00
Merge remote-tracking branch 'openai/main' into IFU-230517

Conflicts:
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/test/unit/language/assert_helper.py
	python/triton/third_party/cuda/bin/ptxas
	test/Conversion/tritongpu_to_llvm.mlir
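The merged hunks below all follow one upstream refactor: per-function analyses become module-level ones. `AxisInfoAnalysis` is replaced by `ModuleAxisInfoAnalysis`, the `const Allocation *allocation, Value smem` constructor pair collapses into a single `ModuleAllocation &allocation` that owns per-function data, and the ad-hoc thread-guard predicates are funneled through a shared `getMask` helper. As a rough mental model of the module-level wrapper (the class body here is an assumed sketch; only the `getFuncData` call actually appears in the diff):

    #include <unordered_map>

    // Stand-ins for the real MLIR types; illustration only.
    struct FunctionOpInterface {};
    struct Allocation {};  // per-function shared-memory buffer table

    // Assumed shape of the module-level wrapper: one Allocation per function,
    // fetched with getFuncData(funcOp) as the converted hunks below do.
    class ModuleAllocationSketch {
      std::unordered_map<FunctionOpInterface *, Allocation> funcData;

    public:
      Allocation *getFuncData(FunctionOpInterface *funcOp) {
        return &funcData[funcOp];
      }
    };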
@@ -13,7 +13,7 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
 
 // Contains some helper functions for both Load and Store conversions.
 struct LoadStoreConversionBase {
-  explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
+  explicit LoadStoreConversionBase(ModuleAxisInfoAnalysis &axisAnalysisPass)
       : axisAnalysisPass(axisAnalysisPass) {}
 
   unsigned getContiguity(Value ptr) const {
@@ -38,7 +38,7 @@ struct LoadStoreConversionBase {
   }
 
 protected:
-  AxisInfoAnalysis &axisAnalysisPass;
+  ModuleAxisInfoAnalysis &axisAnalysisPass;
 };
 
 struct LoadOpConversion
@@ -48,7 +48,8 @@ struct LoadOpConversion
       triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
 
   LoadOpConversion(TritonGPUToLLVMTypeConverter &converter,
-                   AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
+                   ModuleAxisInfoAnalysis &axisAnalysisPass,
+                   PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
         LoadStoreConversionBase(axisAnalysisPass) {}
 
@@ -293,7 +294,8 @@ struct StoreOpConversion
       triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
 
   StoreOpConversion(TritonGPUToLLVMTypeConverter &converter,
-                    AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
+                    ModuleAxisInfoAnalysis &axisAnalysisPass,
+                    PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
         LoadStoreConversionBase(axisAnalysisPass) {}
 
@@ -335,14 +337,7 @@ struct StoreOpConversion
       vec = std::min(vec, maskAlign);
     }
 
-    // numElements = 1 for scalar
-    auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
-    auto numElems = tensorTy ? tensorTy.getNumElements() : 1;
-    Value mask = int_val(1, 1);
-    auto tid = tid_val();
-    mask = and_(mask,
-                icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
-
+    Value mask = getMask(valueTy, rewriter, loc);
     const size_t dtsize =
         std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
     const size_t valueElemNBits = dtsize * 8;
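The deleted block above built the store guard inline: start from true (`int_val(1, 1)`), then AND in `tid * elemsPerThread < numElems`. The new `getMask(valueTy, rewriter, loc)` call presumably centralizes that computation; its body is not part of this diff. A standalone model of the predicate, with plain integers in place of LLVM IR values:

    #include <cassert>

    // Models the predicate the deleted lines emitted with mul/icmp_slt:
    // a thread participates only while its first element index is in range.
    bool threadInBounds(int tid, int elemsPerThread, int numElems) {
      return tid * elemsPerThread < numElems;
    }

    int main() {
      // 100 elements at 4 per thread: threads 0..24 participate, 25 does not.
      assert(threadInBounds(24, 4, 100));   // 96 < 100
      assert(!threadInBounds(25, 4, 100));  // 100 < 100 is false
    }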
@@ -431,11 +426,11 @@ struct AtomicCASOpConversion
       triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
 
   AtomicCASOpConversion(TritonGPUToLLVMTypeConverter &converter,
-                        const Allocation *allocation, Value smem,
-                        AxisInfoAnalysis &axisAnalysisPass,
+                        ModuleAllocation &allocation,
+                        ModuleAxisInfoAnalysis &axisAnalysisPass,
                         PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(
-            converter, allocation, smem, benefit),
+            converter, allocation, benefit),
         LoadStoreConversionBase(axisAnalysisPass) {}
 
 #ifdef USE_ROCM
@@ -526,13 +521,13 @@ struct AtomicCASOpConversion
     auto valElements = getTypeConverter()->unpackLLElements(
         loc, llVal, rewriter, op.getVal().getType());
 
-    auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    auto valueTy = op.getResult().getType();
+    auto TensorTy = valueTy.dyn_cast<RankedTensorType>();
     Type valueElemTy =
         TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
-                 : op.getResult().getType();
+                 : valueTy;
     auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
-    auto tid = tid_val();
-    Value pred = icmp_eq(tid, i32_val(0));
+    Value mask = getMask(valueTy, rewriter, loc);
     PTXBuilder ptxBuilderMemfence;
     auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
     memfence();
@@ -552,7 +547,7 @@ struct AtomicCASOpConversion
     auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
     auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
     atom.global().o("cas").o("b32");
-    atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred);
+    atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
     auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
     barrier();
 
@@ -561,7 +556,7 @@ struct AtomicCASOpConversion
     auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
     auto &st = *ptxBuilderStore.create<PTXInstr>("st");
     st.shared().o("b32");
-    st(dstOprStore, valOprStore).predicate(pred);
+    st(dstOprStore, valOprStore).predicate(mask);
     ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
     ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
     barrier();
@@ -580,11 +575,11 @@ struct AtomicRMWOpConversion
       triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
 
   AtomicRMWOpConversion(TritonGPUToLLVMTypeConverter &converter,
-                        const Allocation *allocation, Value smem,
-                        AxisInfoAnalysis &axisAnalysisPass,
+                        ModuleAllocation &allocation,
+                        ModuleAxisInfoAnalysis &axisAnalysisPass,
                         PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(
-            converter, allocation, smem, benefit),
+            converter, allocation, benefit),
         LoadStoreConversionBase(axisAnalysisPass) {}
 
 #ifdef USE_ROCM
@@ -747,10 +742,11 @@ struct AtomicRMWOpConversion
       maskElements = getTypeConverter()->unpackLLElements(
           loc, llMask, rewriter, op.getMask().getType());
 
-    auto tensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    auto valueTy = op.getResult().getType();
+    auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
     Type valueElemTy =
         tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
-                 : op.getResult().getType();
+                 : valueTy;
     const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
     auto elemsPerThread = getTotalElemsPerThread(val.getType());
     // vec = 1, numElements = 1 for scalar
@@ -763,10 +759,7 @@ struct AtomicRMWOpConversion
       // mask
       numElems = tensorTy.getNumElements();
     }
-    Value mask = int_val(1, 1);
-    auto tid = tid_val();
-    mask = and_(mask,
-                icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
+    Value mask = getMask(valueTy, rewriter, loc);
 
     auto vecTy = vec_ty(valueElemTy, vec);
     SmallVector<Value> resultVals(elemsPerThread);
@@ -846,7 +839,6 @@ struct AtomicRMWOpConversion
         memfenc();
         auto ASMReturnTy = void_ty(ctx);
         ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
-        rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
         atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
         auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
         Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
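Note the dropped `icmp_eq(tid, i32_val(0))` terms here and in AtomicCASOpConversion above: scalar atomics used to be predicated explicitly on thread 0. Under the shared guard that special case is presumably redundant, since a scalar has `numElems == elemsPerThread == 1` (per the "numElements = 1 for scalar" comments) and `tid * 1 < 1` holds only for `tid == 0`. A quick standalone check of the equivalence:

    #include <cassert>

    bool threadInBounds(int tid, int elemsPerThread, int numElems) {
      return tid * elemsPerThread < numElems;
    }

    int main() {
      // Scalar case: the generic guard reduces to tid == 0, so the explicit
      // icmp_eq(tid, 0) predicate the old code ANDed in added nothing.
      for (int tid = 0; tid < 64; ++tid)
        assert(threadInBounds(tid, /*elemsPerThread=*/1, /*numElems=*/1) ==
               (tid == 0));
    }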
@@ -889,7 +881,9 @@ struct InsertSliceOpConversion
     Value dst = op.getDest();
     Value src = op.getSource();
     Value res = op.getResult();
-    assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
+    auto funcOp = op->getParentOfType<FunctionOpInterface>();
+    auto *funcAllocation = allocation->getFuncData(funcOp);
+    assert(funcAllocation->getBufferId(res) == Allocation::InvalidBufferId &&
            "Only support in-place insert_slice for now");
 
     auto srcTy = src.getType().dyn_cast<RankedTensorType>();
@@ -949,12 +943,11 @@ struct InsertSliceAsyncOpConversion
       triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
 
   InsertSliceAsyncOpConversion(
-      TritonGPUToLLVMTypeConverter &converter, const Allocation *allocation,
-      Value smem,
+      TritonGPUToLLVMTypeConverter &converter, ModuleAllocation &allocation,
       ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
-      AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
+      ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
-            converter, allocation, smem, indexCacheInfo, benefit),
+            converter, allocation, indexCacheInfo, benefit),
         LoadStoreConversionBase(axisAnalysisPass) {}
 
   LogicalResult
@@ -967,7 +960,9 @@ struct InsertSliceAsyncOpConversion
     Value res = op.getResult();
     Value mask = op.getMask();
    Value other = op.getOther();
-    assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
+    auto funcOp = op->getParentOfType<FunctionOpInterface>();
+    auto *funcAllocation = allocation->getFuncData(funcOp);
+    assert(funcAllocation->getBufferId(res) == Allocation::InvalidBufferId &&
            "Only support in-place insert_slice_async for now");
 
     auto srcTy = src.getType().cast<RankedTensorType>();
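Both insert_slice conversions now route the buffer check through the owning function: `op->getParentOfType<FunctionOpInterface>()` finds the enclosing function, and `allocation->getFuncData(funcOp)` yields that function's allocation table, against which the unchanged `getBufferId` assertion runs. A minimal model of that lookup-then-assert flow (all types here are stand-ins; only `getFuncData`, `getBufferId`, and `InvalidBufferId` come from the diff):

    #include <cassert>
    #include <unordered_map>

    struct FuncOp {};  // stand-in for the enclosing FunctionOpInterface
    struct Allocation {
      static constexpr int InvalidBufferId = -1;
      // In-place insert_slice: the result aliases its destination, so the
      // function's buffer table must not assign it a buffer of its own.
      int getBufferId(int /*value*/) const { return InvalidBufferId; }
    };
    struct ModuleAllocation {
      std::unordered_map<const FuncOp *, Allocation> perFunc;
      Allocation *getFuncData(const FuncOp *f) { return &perFunc[f]; }
    };

    int main() {
      FuncOp funcOp;
      ModuleAllocation allocation;
      Allocation *funcAllocation = allocation.getFuncData(&funcOp);
      assert(funcAllocation->getBufferId(0) == Allocation::InvalidBufferId &&
             "Only support in-place insert_slice for now");
    }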
@@ -1107,19 +1102,17 @@ struct InsertSliceAsyncOpConversion
 
 void populateLoadStoreOpToLLVMPatterns(
     TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
-    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
-    const Allocation *allocation, Value smem,
+    ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
     ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
     PatternBenefit benefit) {
   patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
   patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
-  patterns.add<AtomicCASOpConversion>(typeConverter, allocation, smem,
+  patterns.add<AtomicCASOpConversion>(typeConverter, allocation,
                                       axisInfoAnalysis, benefit);
-  patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
+  patterns.add<AtomicRMWOpConversion>(typeConverter, allocation,
                                       axisInfoAnalysis, benefit);
-  patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
+  patterns.add<InsertSliceOpConversion>(typeConverter, allocation,
                                         indexCacheInfo, benefit);
-  patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
-                                             indexCacheInfo, axisInfoAnalysis,
-                                             benefit);
+  patterns.add<InsertSliceAsyncOpConversion>(
+      typeConverter, allocation, indexCacheInfo, axisInfoAnalysis, benefit);
 }
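Call sites, such as the pattern registration in TritonGPUToLLVMPass.cpp (one of the conflicted files), have to drop the old `numWarps` and `smem` arguments to match. A hypothetical invocation under the new signature, with the surrounding pass variables assumed rather than taken from this diff:

    // Hypothetical call site; variable names are assumed, the parameter order
    // follows the new declaration above (no numWarps, no separate smem Value).
    populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis,
                                      allocation, indexCacheInfo, benefit);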