|
|
|
|
@@ -41,16 +41,14 @@ Type getShemPtrTy(Type elemTy) {
|
|
|
|
|
return ptr_ty(elemTy, 3);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Get a waveId for M axis.
|
|
|
|
|
Value getWaveM(ConversionPatternRewriter &rewriter, Location loc, Value wave,
|
|
|
|
|
const ArrayRef<unsigned int> &wpt, int elemPerInstr, int M) {
|
|
|
|
|
return urem(urem(wave, i32_val(wpt[0])), i32_val(M / elemPerInstr));
|
|
|
|
|
}
|
|
|
|
|
// Get a waveId for N axis.
|
|
|
|
|
Value getWaveN(ConversionPatternRewriter &rewriter, Location loc, Value wave,
|
|
|
|
|
const ArrayRef<unsigned int> &wpt, int elemPerInstr, int N) {
|
|
|
|
|
Value waveMN = udiv(wave, i32_val(wpt[0]));
|
|
|
|
|
return urem(urem(waveMN, i32_val(wpt[1])), i32_val(N / elemPerInstr));
|
|
|
|
|
// Get waveId inside block of waves.
|
|
|
|
|
Value getWaveIdInBlock(ConversionPatternRewriter &rewriter, Location loc,
|
|
|
|
|
Value waveId, const ArrayRef<unsigned int> &wpt,
|
|
|
|
|
int elemPerInstrNonK, int tensorSizeNonK, int nonKIdx) {
|
|
|
|
|
if (nonKIdx == 1)
|
|
|
|
|
waveId = udiv(waveId, i32_val(wpt[0]));
|
|
|
|
|
return urem(urem(waveId, i32_val(wpt[nonKIdx])),
|
|
|
|
|
i32_val(tensorSizeNonK / elemPerInstrNonK));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
@@ -435,319 +433,137 @@ bool isTransposed(::llvm::ArrayRef<unsigned> order) {
|
|
|
|
|
return order[0] == 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
|
|
|
|
DotOperandEncodingAttr encoding,
|
|
|
|
|
TritonGPUToLLVMTypeConverter *typeConverter, Value tensor,
|
|
|
|
|
const SharedMemoryObject &smemObj) {
|
|
|
|
|
Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
|
|
|
|
|
Location loc, Value tensor, DotOperandEncodingAttr encoding,
|
|
|
|
|
const SharedMemoryObject &smemObj,
|
|
|
|
|
TritonGPUToLLVMTypeConverter *typeConverter, Value thread) {
|
|
|
|
|
assert((opIdx == 0 || opIdx == 1) && "unexpected operand idx");
|
|
|
|
|
|
|
|
|
|
int kDimIdx = opIdx == 0 ? 1 : 0;
|
|
|
|
|
int nonKDimIdx = opIdx == 0 ? 0 : 1;
|
|
|
|
|
|
|
|
|
|
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
|
|
|
|
|
auto nonKDim = mfmaLayout.getNonKDim();
|
|
|
|
|
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
|
|
|
|
|
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
|
|
|
|
|
|
|
|
|
|
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
|
|
|
|
|
SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
|
|
|
|
|
aTensorTy.getShape().end());
|
|
|
|
|
ArrayRef<int64_t> shape = aTensorTy.getShape();
|
|
|
|
|
auto sharedLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
|
|
|
|
auto order = sharedLayout.getOrder();
|
|
|
|
|
|
|
|
|
|
auto aElemTy = aTensorTy.getElementType();
|
|
|
|
|
auto aElemsPerInstr = encoding.getMFMAElemsPerInstr();
|
|
|
|
|
auto mfmaInstrM = aElemsPerInstr[0];
|
|
|
|
|
auto mfmaInstrK = aElemsPerInstr[1];
|
|
|
|
|
auto elemTy = aTensorTy.getElementType();
|
|
|
|
|
auto elemsPerInstr = encoding.getMFMAElemsPerInstr();
|
|
|
|
|
auto mfmaInstrNonK = elemsPerInstr[nonKDimIdx];
|
|
|
|
|
auto mfmaInstrK = elemsPerInstr[kDimIdx];
|
|
|
|
|
|
|
|
|
|
auto numReps = encoding.getMFMARep(shape, aElemTy);
|
|
|
|
|
auto numRepM = numReps[0];
|
|
|
|
|
auto numRepK = numReps[1];
|
|
|
|
|
auto numReps = encoding.getMFMARep(shape);
|
|
|
|
|
auto numRepNonK = numReps[nonKDimIdx];
|
|
|
|
|
auto numRepK = numReps[kDimIdx];
|
|
|
|
|
|
|
|
|
|
unsigned iWaveSize = triton::gpu::getWarpSize(mfmaLayout);
|
|
|
|
|
assert(iWaveSize == 64);
|
|
|
|
|
Value waveSize = i32_val(iWaveSize);
|
|
|
|
|
Value wave = udiv(thread, waveSize);
|
|
|
|
|
Value linearWaveId = udiv(thread, waveSize);
|
|
|
|
|
Value lane = urem(thread, waveSize);
|
|
|
|
|
|
|
|
|
|
Value waveM =
|
|
|
|
|
getWaveM(rewriter, loc, wave, warpsPerCTA, mfmaInstrM, shape[0]);
|
|
|
|
|
int numOfElems = mfmaInstrM * mfmaInstrK / iWaveSize;
|
|
|
|
|
Value spatialWaveId =
|
|
|
|
|
getWaveIdInBlock(rewriter, loc, linearWaveId, warpsPerCTA, mfmaInstrNonK,
|
|
|
|
|
shape[nonKDimIdx], nonKDimIdx);
|
|
|
|
|
int numOfElems = mfmaInstrNonK * mfmaInstrK / iWaveSize;
|
|
|
|
|
assert(numOfElems >= 1);
|
|
|
|
|
unsigned int maxNumWarps = shape[0] / mfmaInstrM;
|
|
|
|
|
int warpsPerGroupM = std::min(warpsPerCTA[0], maxNumWarps);
|
|
|
|
|
aElemTy = typeConverter->convertType(aElemTy);
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> ha;
|
|
|
|
|
unsigned int maxNumWarps = shape[nonKDimIdx] / mfmaInstrNonK;
|
|
|
|
|
int warpsPerGroupNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps);
|
|
|
|
|
elemTy = typeConverter->convertType(elemTy);
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> loadedValues;
|
|
|
|
|
SmallVector<Value> offsets;
|
|
|
|
|
Value smemBase;
|
|
|
|
|
if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) {
|
|
|
|
|
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
|
|
|
|
SmallVector<Value> offsets;
|
|
|
|
|
if (isTransposed(order)) {
|
|
|
|
|
SmallVector<int64_t> elemsPerInstr{mfmaInstrK, mfmaInstrM};
|
|
|
|
|
SmallVector<int64_t> reps{numReps[1], numReps[0]};
|
|
|
|
|
offsets = fastPathComputeOffsetsTy2(rewriter, loc, elemsPerInstr, waveM,
|
|
|
|
|
lane, warpsPerGroupM, numOfElems,
|
|
|
|
|
reps, cSwizzleOffset);
|
|
|
|
|
if (opIdx == 0) {
|
|
|
|
|
if (isTransposed(order)) { // HERE
|
|
|
|
|
SmallVector<int64_t> elemsPerInstr{mfmaInstrK, mfmaInstrNonK};
|
|
|
|
|
SmallVector<int64_t> reps{numReps[1], numReps[0]};
|
|
|
|
|
offsets = fastPathComputeOffsetsTy2(
|
|
|
|
|
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
|
|
|
|
|
warpsPerGroupNonK, numOfElems, reps, cSwizzleOffset);
|
|
|
|
|
} else {
|
|
|
|
|
offsets = fastPathComputeOffsetsTy1(
|
|
|
|
|
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
|
|
|
|
|
warpsPerGroupNonK, numOfElems, numReps, cSwizzleOffset);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
offsets = fastPathComputeOffsetsTy1(rewriter, loc, aElemsPerInstr, waveM,
|
|
|
|
|
lane, warpsPerGroupM, numOfElems,
|
|
|
|
|
numReps, cSwizzleOffset);
|
|
|
|
|
}
|
|
|
|
|
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
|
|
|
|
|
|
|
|
|
|
Type smemPtrTy = getShemPtrTy(aElemTy);
|
|
|
|
|
Type resElemTy = typeConverter->convertType(aElemTy);
|
|
|
|
|
|
|
|
|
|
int loadsPerThread = offsets.size() / (numRepM * numRepK);
|
|
|
|
|
const int elemsPerLoad = numOfElems / loadsPerThread;
|
|
|
|
|
assert(numOfElems % loadsPerThread == 0);
|
|
|
|
|
|
|
|
|
|
for (int m = 0; m < numRepM; ++m) {
|
|
|
|
|
for (int k = 0; k < numRepK; ++k) {
|
|
|
|
|
auto vecTy = vec_ty(resElemTy, numOfElems);
|
|
|
|
|
Value valVec = undef(vecTy);
|
|
|
|
|
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
|
|
|
|
|
auto loadVecTy = vec_ty(aElemTy, elemsPerLoad);
|
|
|
|
|
Value loadOffset =
|
|
|
|
|
offsets[m * loadsPerThread * numRepK + k * loadsPerThread + loadId];
|
|
|
|
|
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
|
|
|
|
|
getShemPtrTy(loadVecTy));
|
|
|
|
|
Value vectorValue = load(loadAddress);
|
|
|
|
|
if (numOfElems > 1) {
|
|
|
|
|
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
|
|
|
|
|
Value elemVal =
|
|
|
|
|
extract_element(aElemTy, vectorValue, i32_val(elemId));
|
|
|
|
|
elemVal = bitcast(elemVal, resElemTy);
|
|
|
|
|
valVec = insert_element(vecTy, valVec, elemVal,
|
|
|
|
|
i32_val(loadId * elemsPerLoad + elemId));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
valVec = extract_element(aElemTy, vectorValue, i32_val(0));
|
|
|
|
|
valVec = bitcast(valVec, resElemTy);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (aElemTy == i8_ty && numOfElems == 4)
|
|
|
|
|
valVec = bitcast(valVec, i32_ty);
|
|
|
|
|
if (aElemTy == i8_ty && numOfElems == 8)
|
|
|
|
|
valVec = bitcast(valVec, i64_ty);
|
|
|
|
|
ha.push_back(valVec);
|
|
|
|
|
if (isTransposed(order)) {
|
|
|
|
|
SmallVector<int64_t> elemsPerInstr{mfmaInstrNonK, mfmaInstrK};
|
|
|
|
|
SmallVector<int64_t> reps{numReps[1], numReps[0]};
|
|
|
|
|
offsets = fastPathComputeOffsetsTy1(
|
|
|
|
|
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
|
|
|
|
|
warpsPerGroupNonK, numOfElems, reps, cSwizzleOffset);
|
|
|
|
|
} else {
|
|
|
|
|
offsets = fastPathComputeOffsetsTy2(
|
|
|
|
|
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
|
|
|
|
|
warpsPerGroupNonK, numOfElems, numReps, cSwizzleOffset);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
|
|
|
|
|
} else { // normal path
|
|
|
|
|
SmallVector<Value> offsets = computeOffsetsAType(
|
|
|
|
|
rewriter, loc, aElemsPerInstr, waveM, lane, warpsPerGroupM, numOfElems,
|
|
|
|
|
numReps, smemObj, sharedLayout, nonKDim);
|
|
|
|
|
if (opIdx == 0) {
|
|
|
|
|
offsets = computeOffsetsAType(rewriter, loc, elemsPerInstr, spatialWaveId,
|
|
|
|
|
lane, warpsPerGroupNonK, numOfElems,
|
|
|
|
|
numReps, smemObj, sharedLayout, nonKDim);
|
|
|
|
|
} else {
|
|
|
|
|
assert(opIdx == 1);
|
|
|
|
|
offsets = computeOffsetsBType(rewriter, loc, elemsPerInstr, spatialWaveId,
|
|
|
|
|
lane, warpsPerGroupNonK, numOfElems,
|
|
|
|
|
numReps, smemObj, sharedLayout, nonKDim);
|
|
|
|
|
}
|
|
|
|
|
smemBase = computeBasePtr(rewriter, loc, smemObj);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value smemBase = computeBasePtr(rewriter, loc, smemObj);
|
|
|
|
|
Type resElemTy = typeConverter->convertType(aElemTy);
|
|
|
|
|
Type resElemTy = typeConverter->convertType(elemTy);
|
|
|
|
|
Type smemPtrTy = getShemPtrTy(elemTy);
|
|
|
|
|
|
|
|
|
|
Type smemPtrTy = getShemPtrTy(aElemTy);
|
|
|
|
|
int loadsPerThread = offsets.size() / (numRepNonK * numRepK);
|
|
|
|
|
int elemsPerLoad = numOfElems / loadsPerThread;
|
|
|
|
|
assert(numOfElems % loadsPerThread == 0);
|
|
|
|
|
|
|
|
|
|
int loadsPerThread = offsets.size() / (numReps[0] * numReps[1]);
|
|
|
|
|
int elemsPerLoad = numOfElems / loadsPerThread;
|
|
|
|
|
|
|
|
|
|
for (int m = 0; m < numRepM; ++m) {
|
|
|
|
|
for (int k = 0; k < numRepK; ++k) {
|
|
|
|
|
auto vecTy = vec_ty(resElemTy, numOfElems);
|
|
|
|
|
Value valVec = undef(vecTy);
|
|
|
|
|
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
|
|
|
|
|
auto loadVecTy = vec_ty(aElemTy, elemsPerLoad);
|
|
|
|
|
Value loadOffset = offsets[m * loadsPerThread * numRepK +
|
|
|
|
|
k * loadsPerThread + loadId];
|
|
|
|
|
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
|
|
|
|
|
getShemPtrTy(loadVecTy));
|
|
|
|
|
Value vectorValue = load(loadAddress);
|
|
|
|
|
if (numOfElems > 1) {
|
|
|
|
|
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
|
|
|
|
|
Value elemVal =
|
|
|
|
|
extract_element(aElemTy, vectorValue, i32_val(elemId));
|
|
|
|
|
elemVal = bitcast(elemVal, resElemTy);
|
|
|
|
|
valVec = insert_element(vecTy, valVec, elemVal,
|
|
|
|
|
i32_val(loadId * elemsPerLoad + elemId));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
valVec = extract_element(aElemTy, vectorValue, i32_val(0));
|
|
|
|
|
valVec = bitcast(valVec, resElemTy);
|
|
|
|
|
for (int nonK = 0; nonK < numRepNonK; ++nonK) {
|
|
|
|
|
for (int k = 0; k < numRepK; ++k) {
|
|
|
|
|
auto vecTy = vec_ty(resElemTy, numOfElems);
|
|
|
|
|
Value valVec = undef(vecTy);
|
|
|
|
|
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
|
|
|
|
|
auto loadVecTy = vec_ty(elemTy, elemsPerLoad);
|
|
|
|
|
Value loadOffset = offsets[nonK * loadsPerThread * numRepK +
|
|
|
|
|
k * loadsPerThread + loadId];
|
|
|
|
|
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
|
|
|
|
|
getShemPtrTy(loadVecTy));
|
|
|
|
|
Value loadedValue = load(loadAddress);
|
|
|
|
|
if (loadsPerThread > 1) {
|
|
|
|
|
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
|
|
|
|
|
Value elemVal =
|
|
|
|
|
extract_element(elemTy, loadedValue, i32_val(elemId));
|
|
|
|
|
elemVal = bitcast(elemVal, resElemTy);
|
|
|
|
|
valVec = insert_element(vecTy, valVec, elemVal,
|
|
|
|
|
i32_val(loadId * elemsPerLoad + elemId));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
valVec = loadedValue;
|
|
|
|
|
}
|
|
|
|
|
if (aElemTy == i8_ty && numOfElems == 4)
|
|
|
|
|
valVec = bitcast(valVec, i32_ty);
|
|
|
|
|
if (aElemTy == i8_ty && numOfElems == 8)
|
|
|
|
|
valVec = bitcast(valVec, i64_ty);
|
|
|
|
|
ha.push_back(valVec);
|
|
|
|
|
}
|
|
|
|
|
loadedValues.push_back(valVec);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MLIRContext *ctx = mfmaLayout.getContext();
|
|
|
|
|
Type structTy = LLVM::LLVMStructType::getLiteral(
|
|
|
|
|
ctx, SmallVector<Type>(ha.size(), ha[0].getType()));
|
|
|
|
|
auto result = typeConverter->packLLElements(loc, ha, rewriter, structTy);
|
|
|
|
|
ctx, SmallVector<Type>(loadedValues.size(), loadedValues[0].getType()));
|
|
|
|
|
auto result =
|
|
|
|
|
typeConverter->packLLElements(loc, loadedValues, rewriter, structTy);
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Load operand $b of an MFMA dot from a shared-memory object and return the
// loaded values packed into a single LLVM struct value.
//
// Parameters:
//   rewriter, loc  - builder and location used for every emitted operation.
//   thread         - linear thread id; split below into a wave id
//                    (thread / 64) and a lane id (thread % 64).
//   encoding       - dot-operand encoding; its parent must be an
//                    MfmaEncodingAttr (enforced by the cast).
//   typeConverter  - converts the tensor element type and packs the result.
//   tensor         - ranked tensor whose encoding must be a
//                    SharedEncodingAttr (enforced by the cast).
//   smemObj        - shared-memory object the values are read from.
//
// Returns a struct with numRepN * numRepK members, pushed in n-major, k-minor
// order.
//
// The function has two ways of producing the per-load offsets and the base
// pointer (a "fast path" chosen by fastPathAvailable(), and a normal path);
// everything after that -- the load loop and the packing -- is shared.
Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
            DotOperandEncodingAttr encoding,
            TritonGPUToLLVMTypeConverter *typeConverter, Value tensor,
            const SharedMemoryObject &smemObj) {
  auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
  auto nonKDim = mfmaLayout.getNonKDim();
  assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
  auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();

  auto bTensorTy = tensor.getType().cast<RankedTensorType>();
  ArrayRef<int64_t> shape = bTensorTy.getShape();
  auto sharedLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
  auto order = sharedLayout.getOrder();

  // For operand $b index 0 is the K dimension and index 1 is N.
  auto bElemTy = bTensorTy.getElementType();
  auto bElemsPerInstr = encoding.getMFMAElemsPerInstr();
  auto mfmaInstrK = bElemsPerInstr[0];
  auto mfmaInstrN = bElemsPerInstr[1];

  auto numReps = encoding.getMFMARep(shape, bElemTy);
  auto numRepK = numReps[0];
  auto numRepN = numReps[1];

  unsigned iWaveSize = triton::gpu::getWarpSize(mfmaLayout);
  assert(iWaveSize == 64);
  Value waveSize = i32_val(iWaveSize);
  Value wave = udiv(thread, waveSize);
  Value lane = urem(thread, waveSize);

  Value waveN =
      getWaveN(rewriter, loc, wave, warpsPerCTA, mfmaInstrN, shape[1]);
  // Number of elements each lane holds for one instruction tile.
  int numOfElems = mfmaInstrK * mfmaInstrN / iWaveSize;
  assert(numOfElems >= 1);

  unsigned int maxNumWarps = shape[1] / mfmaInstrN;
  int warpsPerGroupN = std::min(warpsPerCTA[1], maxNumWarps);
  bElemTy = typeConverter->convertType(bElemTy);

  // Path-specific part: compute the offsets and the base pointer.
  SmallVector<Value> offsets;
  Value smemBase;
  if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) {
    Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);

    if (isTransposed(order)) {
      // Swap the per-instruction sizes and repetition counts before calling
      // the Ty1 helper.
      SmallVector<int64_t> elemsPerInstr{mfmaInstrN, mfmaInstrK};
      SmallVector<int64_t> reps{numReps[1], numReps[0]};
      offsets = fastPathComputeOffsetsTy1(rewriter, loc, elemsPerInstr, waveN,
                                          lane, warpsPerGroupN, numOfElems,
                                          reps, cSwizzleOffset);
    } else {
      offsets = fastPathComputeOffsetsTy2(rewriter, loc, bElemsPerInstr, waveN,
                                          lane, warpsPerGroupN, numOfElems,
                                          numReps, cSwizzleOffset);
    }

    smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
  } else { // normal path
    offsets = computeOffsetsBType(rewriter, loc, bElemsPerInstr, waveN, lane,
                                  warpsPerGroupN, numOfElems, numReps, smemObj,
                                  sharedLayout, nonKDim);
    smemBase = computeBasePtr(rewriter, loc, smemObj);
  }

  // bElemTy was already converted above; it is converted once more here
  // because the original code did the same on both paths.
  Type resElemTy = typeConverter->convertType(bElemTy);
  Type smemPtrTy = getShemPtrTy(bElemTy);

  // Each (n, k) repetition consumes `loadsPerThread` consecutive offsets, and
  // every load brings in `elemsPerLoad` elements. The split must be exact,
  // otherwise trailing elements of the result vector would stay undefined.
  const int loadsPerThread = offsets.size() / (numRepN * numRepK);
  const int elemsPerLoad = numOfElems / loadsPerThread;
  assert(numOfElems % loadsPerThread == 0);

  SmallVector<Value> hb;
  for (int n = 0; n < numRepN; ++n) {
    for (int k = 0; k < numRepK; ++k) {
      auto vecTy = vec_ty(resElemTy, numOfElems);
      Value valVec = undef(vecTy);
      for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
        auto loadVecTy = vec_ty(bElemTy, elemsPerLoad);
        Value loadOffset =
            offsets[n * loadsPerThread * numRepK + k * loadsPerThread + loadId];
        Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
                                    getShemPtrTy(loadVecTy));
        Value vectorValue = load(loadAddress);
        if (numOfElems > 1) {
          // Copy the loaded elements into their slots of the result vector.
          for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
            Value elemVal =
                extract_element(bElemTy, vectorValue, i32_val(elemId));
            elemVal = bitcast(elemVal, resElemTy);
            valVec = insert_element(vecTy, valVec, elemVal,
                                    i32_val(loadId * elemsPerLoad + elemId));
          }
        } else {
          // Single element per lane: the result is the scalar itself.
          valVec = extract_element(bElemTy, vectorValue, i32_val(0));
          valVec = bitcast(valVec, resElemTy);
        }
      }
      // Groups of four or eight i8 values are reinterpreted as one i32 / i64.
      if (bElemTy == i8_ty && numOfElems == 4)
        valVec = bitcast(valVec, i32_ty);
      if (bElemTy == i8_ty && numOfElems == 8)
        valVec = bitcast(valVec, i64_ty);
      hb.push_back(valVec);
    }
  }

  MLIRContext *ctx = mfmaLayout.getContext();
  Type structTy = LLVM::LLVMStructType::getLiteral(
      ctx, SmallVector<Type>(hb.size(), hb[0].getType()));
  auto result = typeConverter->packLLElements(loc, hb, rewriter, structTy);
  return result;
}
|
|
|
|
|
|
|
|
|
|
// Dispatch the shared-memory -> dot-operand conversion on the operand index:
// opIdx 0 is operand $a (handled by loadA), opIdx 1 is operand $b (handled by
// loadB). Any other index trips the assertion; when assertions are compiled
// out a default-constructed Value is returned instead.
Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
                    Location loc, Value tensor, DotOperandEncodingAttr encoding,
                    const SharedMemoryObject &smemObj,
                    TritonGPUToLLVMTypeConverter *typeConverter, Value thread) {
  if (opIdx == 0) {
    // operand $a
    return loadA(rewriter, loc, thread, encoding, typeConverter, tensor,
                 smemObj);
  }
  if (opIdx == 1) {
    // operand $b
    return loadB(rewriter, loc, thread, encoding, typeConverter, tensor,
                 smemObj);
  }

  assert(false && "unexpected operand idx");
  return Value();
}
|
|
|
|
|
|
|
|
|
|
} // namespace SharedToDotOperandMFMA
|
|
|
|
|
|
|
|
|
|
#endif // ifdef USE_ROCM
|
|
|
|
|
|