[BACKEND] move struct packing/unpacking to type converter and give a more explicit name (#1281)

This is the first in a series of PRs meant to clean up how the backend
handles codegen for dot operand layouts.
This commit is contained in:
Philippe Tillet
2023-03-05 16:04:29 -08:00
committed by GitHub
parent d376020f90
commit 0f82fac60e
22 changed files with 442 additions and 410 deletions

View File

@@ -7,9 +7,7 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
@@ -18,19 +16,6 @@ struct LoadStoreConversionBase {
explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
: axisAnalysisPass(axisAnalysisPass) {}
// Collect the LLVM element values that correspond to \param value, given its
// lowered form \param llValue.
//  - A null \param value yields an empty vector.
//  - A non-struct \param llValue is returned as a single-element vector.
//  - A struct \param llValue is expanded via getElementsFromStruct.
static SmallVector<Value> getLLVMElems(Value value, Value llValue,
                                       ConversionPatternRewriter &rewriter,
                                       Location loc) {
  if (!value)
    return {};
  const bool isStruct = llValue.getType().isa<LLVM::LLVMStructType>();
  if (isStruct) {
    // Here, we assume that all inputs should have a blockedLayout
    return getElementsFromStruct(loc, llValue, rewriter);
  }
  return {llValue};
}
unsigned getContiguity(Value ptr) const {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
@@ -62,7 +47,7 @@ struct LoadOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
LoadOpConversion(LLVMTypeConverter &converter,
LoadOpConversion(TritonGPUToLLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
@@ -92,13 +77,15 @@ struct LoadOpConversion
vec = std::min<size_t>(vec, getMaskAlignment(mask));
// Get the LLVM values for pointers
auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
ptr.getType());
assert(ptrElems.size() == numElems);
// Get the LLVM values for mask
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
mask.getType());
assert(maskElems.size() == numElems);
}
@@ -114,7 +101,11 @@ struct LoadOpConversion
otherIsSplatConstInt = true;
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
}
auto otherElems = getLLVMElems(other, llOther, rewriter, loc);
SmallVector<Value> otherElems;
if (other) {
otherElems = getTypeConverter()->unpackLLElements(loc, llOther, rewriter,
other.getType());
}
// vectorized iteration through all the pointer/mask/other elements
const int valueElemNbits =
@@ -245,8 +236,8 @@ struct LoadOpConversion
} // end vec
Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
Value resultStruct = getTypeConverter()->packLLElements(
loc, loadedVals, rewriter, llvmResultStructTy);
rewriter.replaceOp(op, {resultStruct});
return success();
}
@@ -258,7 +249,7 @@ struct StoreOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
StoreOpConversion(LLVMTypeConverter &converter,
StoreOpConversion(TritonGPUToLLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
@@ -284,14 +275,17 @@ struct StoreOpConversion
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
auto valueElems = getLLVMElems(value, llValue, rewriter, loc);
auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
ptr.getType());
auto valueElems = getTypeConverter()->unpackLLElements(
loc, llValue, rewriter, value.getType());
assert(ptrElems.size() == valueElems.size());
// Determine the vectorization size
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
mask.getType());
assert(valueElems.size() == maskElems.size());
unsigned maskAlign = getMaskAlignment(mask);
@@ -373,7 +367,7 @@ struct AtomicCASOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
AtomicCASOpConversion(LLVMTypeConverter &converter,
AtomicCASOpConversion(TritonGPUToLLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
@@ -391,9 +385,12 @@ struct AtomicCASOpConversion
Value llCmp = adaptor.getCmp();
Value llVal = adaptor.getVal();
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
auto cmpElements = getElementsFromStruct(loc, llCmp, rewriter);
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
auto ptrElements = getTypeConverter()->unpackLLElements(
loc, llPtr, rewriter, op.getPtr().getType());
auto cmpElements = getTypeConverter()->unpackLLElements(
loc, llCmp, rewriter, op.getCmp().getType());
auto valElements = getTypeConverter()->unpackLLElements(
loc, llVal, rewriter, op.getVal().getType());
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
Type valueElemTy =
@@ -447,7 +444,7 @@ struct AtomicRMWOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
AtomicRMWOpConversion(LLVMTypeConverter &converter,
AtomicRMWOpConversion(TritonGPUToLLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
@@ -462,16 +459,21 @@ struct AtomicRMWOpConversion
MLIRContext *ctx = rewriter.getContext();
auto atomicRmwAttr = op.getAtomicRmwOp();
Value ptr = op.getPtr();
Value val = op.getVal();
Value ptr = op.getPtr();
Value _mask = op.getMask();
Value llPtr = adaptor.getPtr();
Value llVal = adaptor.getVal();
Value llMask = adaptor.getMask();
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
auto valElements = getTypeConverter()->unpackLLElements(
loc, llVal, rewriter, val.getType());
auto ptrElements = getTypeConverter()->unpackLLElements(
loc, llPtr, rewriter, ptr.getType());
auto maskElements = getTypeConverter()->unpackLLElements(
loc, llMask, rewriter, _mask.getType());
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
Type valueElemTy =
@@ -587,8 +589,8 @@ struct AtomicRMWOpConversion
}
if (valueTy) {
Type structTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, structTy);
Value resultStruct = getTypeConverter()->packLLElements(
loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, {resultStruct});
}
return success();
@@ -668,7 +670,8 @@ struct InsertSliceAsyncOpConversion
triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
InsertSliceAsyncOpConversion(
LLVMTypeConverter &converter, const Allocation *allocation, Value smem,
TritonGPUToLLVMTypeConverter &converter, const Allocation *allocation,
Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
@@ -704,7 +707,8 @@ struct InsertSliceAsyncOpConversion
Value llIndex = adaptor.getIndex();
// %src
auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);
auto srcElems = getTypeConverter()->unpackLLElements(loc, llSrc, rewriter,
src.getType());
// %dst
auto dstTy = dst.getType().cast<RankedTensorType>();
@@ -730,7 +734,8 @@ struct InsertSliceAsyncOpConversion
// %mask
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
mask.getType());
assert(srcElems.size() == maskElems.size());
}
@@ -741,7 +746,8 @@ struct InsertSliceAsyncOpConversion
// It's not necessary for now because the pipeline pass will skip
// generating insert_slice_async if the load op has any "other" tensor.
// assert(false && "insert_slice_async: Other value not supported yet");
otherElems = getLLVMElems(other, llOther, rewriter, loc);
otherElems = getTypeConverter()->unpackLLElements(loc, llOther, rewriter,
other.getType());
assert(srcElems.size() == otherElems.size());
}
@@ -821,7 +827,7 @@ struct InsertSliceAsyncOpConversion
};
void populateLoadStoreOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,