mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] move struct packing/unpacking to type converter and give a more explicit name (#1281)
This is the first in a series of PRs meant to clean up how the backend handles the codegen for dot operand layouts.
This commit is contained in:
@@ -7,9 +7,7 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::getElementsFromStruct;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStructFromElements;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
@@ -18,19 +16,6 @@ struct LoadStoreConversionBase {
|
||||
explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
|
||||
: axisAnalysisPass(axisAnalysisPass) {}
|
||||
|
||||
// Get corresponding LLVM element values of \param value.
|
||||
static SmallVector<Value> getLLVMElems(Value value, Value llValue,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc) {
|
||||
if (!value)
|
||||
return {};
|
||||
if (!llValue.getType().isa<LLVM::LLVMStructType>())
|
||||
return {llValue};
|
||||
// Here, we assume that all inputs should have a blockedLayout
|
||||
auto valueVals = getElementsFromStruct(loc, llValue, rewriter);
|
||||
return valueVals;
|
||||
}
|
||||
|
||||
unsigned getContiguity(Value ptr) const {
|
||||
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy)
|
||||
@@ -62,7 +47,7 @@ struct LoadOpConversion
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LoadOpConversion(LLVMTypeConverter &converter,
|
||||
LoadOpConversion(TritonGPUToLLVMTypeConverter &converter,
|
||||
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
|
||||
LoadStoreConversionBase(axisAnalysisPass) {}
|
||||
@@ -92,13 +77,15 @@ struct LoadOpConversion
|
||||
vec = std::min<size_t>(vec, getMaskAlignment(mask));
|
||||
|
||||
// Get the LLVM values for pointers
|
||||
auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
|
||||
auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
|
||||
ptr.getType());
|
||||
assert(ptrElems.size() == numElems);
|
||||
|
||||
// Get the LLVM values for mask
|
||||
SmallVector<Value> maskElems;
|
||||
if (llMask) {
|
||||
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
|
||||
maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
|
||||
mask.getType());
|
||||
assert(maskElems.size() == numElems);
|
||||
}
|
||||
|
||||
@@ -114,7 +101,11 @@ struct LoadOpConversion
|
||||
otherIsSplatConstInt = true;
|
||||
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
|
||||
}
|
||||
auto otherElems = getLLVMElems(other, llOther, rewriter, loc);
|
||||
SmallVector<Value> otherElems;
|
||||
if (other) {
|
||||
otherElems = getTypeConverter()->unpackLLElements(loc, llOther, rewriter,
|
||||
other.getType());
|
||||
}
|
||||
|
||||
// vectorized iteration through all the pointer/mask/other elements
|
||||
const int valueElemNbits =
|
||||
@@ -245,8 +236,8 @@ struct LoadOpConversion
|
||||
} // end vec
|
||||
|
||||
Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
|
||||
Value resultStruct =
|
||||
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
|
||||
Value resultStruct = getTypeConverter()->packLLElements(
|
||||
loc, loadedVals, rewriter, llvmResultStructTy);
|
||||
rewriter.replaceOp(op, {resultStruct});
|
||||
return success();
|
||||
}
|
||||
@@ -258,7 +249,7 @@ struct StoreOpConversion
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
StoreOpConversion(LLVMTypeConverter &converter,
|
||||
StoreOpConversion(TritonGPUToLLVMTypeConverter &converter,
|
||||
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
|
||||
LoadStoreConversionBase(axisAnalysisPass) {}
|
||||
@@ -284,14 +275,17 @@ struct StoreOpConversion
|
||||
unsigned vec = getVectorSize(ptr);
|
||||
unsigned numElems = getElemsPerThread(ptr.getType());
|
||||
|
||||
auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
|
||||
auto valueElems = getLLVMElems(value, llValue, rewriter, loc);
|
||||
auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
|
||||
ptr.getType());
|
||||
auto valueElems = getTypeConverter()->unpackLLElements(
|
||||
loc, llValue, rewriter, value.getType());
|
||||
assert(ptrElems.size() == valueElems.size());
|
||||
|
||||
// Determine the vectorization size
|
||||
SmallVector<Value> maskElems;
|
||||
if (llMask) {
|
||||
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
|
||||
maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
|
||||
mask.getType());
|
||||
assert(valueElems.size() == maskElems.size());
|
||||
|
||||
unsigned maskAlign = getMaskAlignment(mask);
|
||||
@@ -373,7 +367,7 @@ struct AtomicCASOpConversion
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
AtomicCASOpConversion(LLVMTypeConverter &converter,
|
||||
AtomicCASOpConversion(TritonGPUToLLVMTypeConverter &converter,
|
||||
const Allocation *allocation, Value smem,
|
||||
AxisInfoAnalysis &axisAnalysisPass,
|
||||
PatternBenefit benefit)
|
||||
@@ -391,9 +385,12 @@ struct AtomicCASOpConversion
|
||||
Value llCmp = adaptor.getCmp();
|
||||
Value llVal = adaptor.getVal();
|
||||
|
||||
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
|
||||
auto cmpElements = getElementsFromStruct(loc, llCmp, rewriter);
|
||||
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
|
||||
auto ptrElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llPtr, rewriter, op.getPtr().getType());
|
||||
auto cmpElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llCmp, rewriter, op.getCmp().getType());
|
||||
auto valElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llVal, rewriter, op.getVal().getType());
|
||||
|
||||
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
Type valueElemTy =
|
||||
@@ -447,7 +444,7 @@ struct AtomicRMWOpConversion
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
AtomicRMWOpConversion(LLVMTypeConverter &converter,
|
||||
AtomicRMWOpConversion(TritonGPUToLLVMTypeConverter &converter,
|
||||
const Allocation *allocation, Value smem,
|
||||
AxisInfoAnalysis &axisAnalysisPass,
|
||||
PatternBenefit benefit)
|
||||
@@ -462,16 +459,21 @@ struct AtomicRMWOpConversion
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
|
||||
auto atomicRmwAttr = op.getAtomicRmwOp();
|
||||
Value ptr = op.getPtr();
|
||||
|
||||
Value val = op.getVal();
|
||||
Value ptr = op.getPtr();
|
||||
Value _mask = op.getMask();
|
||||
|
||||
Value llPtr = adaptor.getPtr();
|
||||
Value llVal = adaptor.getVal();
|
||||
Value llMask = adaptor.getMask();
|
||||
|
||||
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
|
||||
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
|
||||
auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
|
||||
auto valElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llVal, rewriter, val.getType());
|
||||
auto ptrElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llPtr, rewriter, ptr.getType());
|
||||
auto maskElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llMask, rewriter, _mask.getType());
|
||||
|
||||
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
Type valueElemTy =
|
||||
@@ -587,8 +589,8 @@ struct AtomicRMWOpConversion
|
||||
}
|
||||
if (valueTy) {
|
||||
Type structTy = getTypeConverter()->convertType(valueTy);
|
||||
Value resultStruct =
|
||||
getStructFromElements(loc, resultVals, rewriter, structTy);
|
||||
Value resultStruct = getTypeConverter()->packLLElements(
|
||||
loc, resultVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, {resultStruct});
|
||||
}
|
||||
return success();
|
||||
@@ -668,7 +670,8 @@ struct InsertSliceAsyncOpConversion
|
||||
triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
InsertSliceAsyncOpConversion(
|
||||
LLVMTypeConverter &converter, const Allocation *allocation, Value smem,
|
||||
TritonGPUToLLVMTypeConverter &converter, const Allocation *allocation,
|
||||
Value smem,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
|
||||
@@ -704,7 +707,8 @@ struct InsertSliceAsyncOpConversion
|
||||
Value llIndex = adaptor.getIndex();
|
||||
|
||||
// %src
|
||||
auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);
|
||||
auto srcElems = getTypeConverter()->unpackLLElements(loc, llSrc, rewriter,
|
||||
src.getType());
|
||||
|
||||
// %dst
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
@@ -730,7 +734,8 @@ struct InsertSliceAsyncOpConversion
|
||||
// %mask
|
||||
SmallVector<Value> maskElems;
|
||||
if (llMask) {
|
||||
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
|
||||
maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
|
||||
mask.getType());
|
||||
assert(srcElems.size() == maskElems.size());
|
||||
}
|
||||
|
||||
@@ -741,7 +746,8 @@ struct InsertSliceAsyncOpConversion
|
||||
// It's not necessary for now because the pipeline pass will skip
|
||||
// generating insert_slice_async if the load op has any "other" tensor.
|
||||
// assert(false && "insert_slice_async: Other value not supported yet");
|
||||
otherElems = getLLVMElems(other, llOther, rewriter, loc);
|
||||
otherElems = getTypeConverter()->unpackLLElements(loc, llOther, rewriter,
|
||||
other.getType());
|
||||
assert(srcElems.size() == otherElems.size());
|
||||
}
|
||||
|
||||
@@ -821,7 +827,7 @@ struct InsertSliceAsyncOpConversion
|
||||
};
|
||||
|
||||
void populateLoadStoreOpToLLVMPatterns(
|
||||
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
|
||||
Reference in New Issue
Block a user