mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER][BACKEND] Cleaned up Volta codegen (#1185)
This commit is contained in:
@@ -159,7 +159,7 @@ struct DotOpMmaV1ConversionHelper {
|
||||
return M / shapePerCTAM * param.rep[0];
|
||||
}
|
||||
|
||||
using CoordTy = SmallVector<Value, 2>;
|
||||
using CoordTy = SmallVector<Value>;
|
||||
// Get the coordinates(m,n) of the elements emit by a thread in accumulator.
|
||||
static SmallVector<CoordTy>
|
||||
getMNCoords(Value thread, ConversionPatternRewriter &rewriter,
|
||||
|
||||
@@ -69,36 +69,6 @@ struct BroadcastOpConversion
|
||||
auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape);
|
||||
auto resultOffsets = emitOffsetForLayout(resultLayout, resultShape);
|
||||
SmallVector<Value> srcVals = getElementsFromStruct(loc, src, rewriter);
|
||||
if (auto srcMma = srcLayout.dyn_cast<MmaEncodingAttr>()) {
|
||||
|
||||
// NOTE: This is just an naive fix, but for MMA layout, and 2-d fix should
|
||||
// be all right.
|
||||
// TODO[Superjomn]: Replace this with a generic implementation.
|
||||
if (srcMma.isVolta()) {
|
||||
assert(srcTy.getElementType().isF16() &&
|
||||
"Unexpected data type on Volta");
|
||||
int numElemsPerThread = srcMma.getElemsPerThread(resultTy.getShape());
|
||||
int srcUniqElems = srcVals.size() / 2;
|
||||
int dup = numElemsPerThread / srcUniqElems;
|
||||
SmallVector<Value> retVals;
|
||||
if (srcShape[0] == 1) { // add-cols
|
||||
for (int i = 0; i < srcUniqElems; ++i)
|
||||
for (int k = 0; k < dup; ++k)
|
||||
retVals.push_back(srcVals[i * 2]);
|
||||
|
||||
} else { // add-rows
|
||||
for (int k = 0; k < dup; ++k)
|
||||
for (int i = 0; i < srcUniqElems; ++i)
|
||||
retVals.push_back(srcVals[i]);
|
||||
}
|
||||
|
||||
auto llvmStructTy = getTypeConverter()->convertType(resultTy);
|
||||
Value ret = getStructFromElements(loc, retVals, rewriter, llvmStructTy);
|
||||
|
||||
rewriter.replaceOp(op, {ret});
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
|
||||
for (size_t i = 0; i < srcOffsets.size(); i++) {
|
||||
@@ -115,6 +85,7 @@ struct BroadcastOpConversion
|
||||
}
|
||||
|
||||
auto llvmStructTy = getTypeConverter()->convertType(resultTy);
|
||||
|
||||
Value resultStruct =
|
||||
getStructFromElements(loc, resultVals, rewriter, llvmStructTy);
|
||||
rewriter.replaceOp(op, {resultStruct});
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
|
||||
//
|
||||
#include "DotOpHelpers.h"
|
||||
#include "Utility.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
@@ -693,29 +694,103 @@ private:
|
||||
emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const MmaEncodingAttr &mmaLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
llvm_unreachable("emitIndicesForMmaLayoutV1 not implemented");
|
||||
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
auto fpw = LLVM::DotOpMmaV1ConversionHelper::fpw;
|
||||
auto [isARow, isBRow, isAVec4, isBVec4, id] =
|
||||
mmaLayout.decodeVoltaLayoutStates();
|
||||
|
||||
Value thread = getThreadId(rewriter, loc);
|
||||
auto *ctx = thread.getContext();
|
||||
Value _1 = i32_val(1);
|
||||
Value _2 = i32_val(2);
|
||||
Value _4 = i32_val(4);
|
||||
Value _16 = i32_val(16);
|
||||
Value _32 = i32_val(32);
|
||||
Value _fpw0 = i32_val(fpw[0]);
|
||||
Value _fpw1 = i32_val(fpw[1]);
|
||||
|
||||
LLVM::DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
|
||||
LLVM::DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
|
||||
|
||||
SmallVector<int, 2> rep({aParam.rep[0], bParam.rep[1]});
|
||||
SmallVector<int, 2> spw({aParam.spw[0], bParam.spw[1]});
|
||||
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
|
||||
|
||||
Value lane = urem(thread, _32);
|
||||
Value warp = udiv(thread, _32);
|
||||
|
||||
Value warp0 = urem(warp, i32_val(wpt[0]));
|
||||
Value warp12 = udiv(warp, i32_val(wpt[0]));
|
||||
Value warp1 = urem(warp12, i32_val(wpt[1]));
|
||||
|
||||
// warp offset
|
||||
Value offWarpM = mul(warp0, i32_val(spw[0]));
|
||||
Value offWarpN = mul(warp1, i32_val(spw[1]));
|
||||
// quad offset
|
||||
Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0);
|
||||
Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1);
|
||||
// pair offset
|
||||
Value offPairM = udiv(urem(lane, _16), _4);
|
||||
offPairM = urem(offPairM, _fpw0);
|
||||
offPairM = mul(offPairM, _4);
|
||||
Value offPairN = udiv(urem(lane, _16), _4);
|
||||
offPairN = udiv(offPairN, _fpw0);
|
||||
offPairN = urem(offPairN, _fpw1);
|
||||
offPairN = mul(offPairN, _4);
|
||||
offPairM = mul(offPairM, i32_val(rep[0] / 2));
|
||||
offQuadM = mul(offQuadM, i32_val(rep[0] / 2));
|
||||
offPairN = mul(offPairN, i32_val(rep[1] / 2));
|
||||
offQuadN = mul(offQuadN, i32_val(rep[1] / 2));
|
||||
// quad pair offset
|
||||
Value offLaneM = add(offPairM, offQuadM);
|
||||
Value offLaneN = add(offPairN, offQuadN);
|
||||
// a, b offset
|
||||
Value offsetAM = add(offWarpM, offLaneM);
|
||||
Value offsetBN = add(offWarpN, offLaneN);
|
||||
// m indices
|
||||
Value offsetCM = add(and_(lane, _1), offsetAM);
|
||||
// n indices
|
||||
Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN)));
|
||||
return {offsetCM, offsetCN};
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>>
|
||||
emitOffsetForMmaLayoutV1(const MmaEncodingAttr &mmaLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
SmallVector<SmallVector<unsigned>> ret;
|
||||
|
||||
for (unsigned i = 0; i < shape[0];
|
||||
i += getShapePerCTA(mmaLayout, shape)[0]) {
|
||||
for (unsigned j = 0; j < shape[1];
|
||||
j += getShapePerCTA(mmaLayout, shape)[1]) {
|
||||
ret.push_back({i, j});
|
||||
ret.push_back({i, j + 1});
|
||||
ret.push_back({i + 2, j});
|
||||
ret.push_back({i + 2, j + 1});
|
||||
ret.push_back({i, j + 8});
|
||||
ret.push_back({i, j + 9});
|
||||
ret.push_back({i + 2, j + 8});
|
||||
ret.push_back({i + 2, j + 9});
|
||||
auto [isARow, isBRow, isAVec4, isBVec4, id] =
|
||||
mmaLayout.decodeVoltaLayoutStates();
|
||||
LLVM::DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
|
||||
LLVM::DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
auto fpw = LLVM::DotOpMmaV1ConversionHelper::fpw;
|
||||
SmallVector<int, 2> rep({aParam.rep[0], bParam.rep[1]});
|
||||
SmallVector<int, 2> spw({aParam.spw[0], bParam.spw[1]});
|
||||
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
|
||||
|
||||
SmallVector<unsigned> idxM;
|
||||
for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0])
|
||||
for (unsigned mm = 0; mm < rep[0]; ++mm)
|
||||
idxM.push_back(m + mm * 2);
|
||||
|
||||
SmallVector<unsigned> idxN;
|
||||
for (int n = 0; n < shape[1]; n += shapePerCTA[1]) {
|
||||
for (int nn = 0; nn < rep[1]; ++nn) {
|
||||
idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1]);
|
||||
idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1);
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>> ret;
|
||||
for (unsigned x1 : idxN) { // N
|
||||
for (unsigned x0 : idxM) { // M
|
||||
SmallVector<unsigned> idx(2);
|
||||
idx[0] = x0; // M
|
||||
idx[1] = x1; // N
|
||||
ret.push_back(std::move(idx));
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -761,17 +836,9 @@ private:
|
||||
|
||||
// Emit indices calculation within each ConversionPattern, and returns a
|
||||
// [elemsPerThread X rank] index matrix.
|
||||
|
||||
// TODO: [phil] redundant indices computation do not appear to hurt
|
||||
// performance much, but they could still significantly slow down
|
||||
// computations.
|
||||
SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Attribute &layout, ArrayRef<int64_t> shape) const {
|
||||
if (auto mmaLayout = layout.template dyn_cast<MmaEncodingAttr>()) {
|
||||
assert(!mmaLayout.isVolta());
|
||||
}
|
||||
|
||||
// step 1, delinearize threadId to get the base index
|
||||
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape);
|
||||
// step 2, get offset of each element
|
||||
@@ -785,7 +852,6 @@ private:
|
||||
for (unsigned n = 0; n < elemsPerThread; ++n)
|
||||
for (unsigned k = 0; k < rank; ++k)
|
||||
multiDimIdx[n][k] = add(multiDimBase[k], idx_val(offset[n][k]));
|
||||
|
||||
return multiDimIdx;
|
||||
}
|
||||
|
||||
|
||||
@@ -29,124 +29,15 @@ struct SplatOpConversion
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc) {
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
|
||||
tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
|
||||
auto srcType = typeConverter->convertType(elemType);
|
||||
auto llSrc = bitcast(constVal, srcType);
|
||||
size_t elemsPerThread = getElemsPerThread(tensorTy);
|
||||
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
|
||||
llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
|
||||
auto structTy =
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
|
||||
auto srcType = typeConverter->convertType(elemType);
|
||||
auto llSrc = bitcast(constVal, srcType);
|
||||
size_t elemsPerThread = getElemsPerThread(tensorTy);
|
||||
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
|
||||
llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
|
||||
auto structTy =
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
|
||||
|
||||
return getStructFromElements(loc, elems, rewriter, structTy);
|
||||
} else if (auto dotLayout =
|
||||
tensorTy.getEncoding()
|
||||
.dyn_cast<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
return convertSplatLikeOpWithDotOperandLayout(
|
||||
dotLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
|
||||
} else if (auto mmaLayout =
|
||||
tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
|
||||
return convertSplatLikeOpWithMmaLayout(
|
||||
mmaLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
|
||||
} else
|
||||
assert(false && "Unsupported layout found in ConvertSplatLikeOp");
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
static Value convertSplatLikeOpWithDotOperandLayout(
|
||||
const triton::gpu::DotOperandEncodingAttr &layout, Type resType,
|
||||
Type elemType, Value constVal, TypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto shape = tensorTy.getShape();
|
||||
auto dotOperand =
|
||||
tensorTy.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
auto parent = layout.getParent();
|
||||
Value retVal = constVal;
|
||||
Type retTy = elemType;
|
||||
int numElems{};
|
||||
if (auto mmaLayout = parent.dyn_cast<MmaEncodingAttr>()) {
|
||||
Type matTy;
|
||||
if (mmaLayout.isAmpere()) {
|
||||
numElems = layout.getOpIdx() == 0
|
||||
? MMA16816ConversionHelper::getANumElemsPerThread(
|
||||
tensorTy, mmaLayout.getWarpsPerCTA()[0])
|
||||
: MMA16816ConversionHelper::getBNumElemsPerThread(
|
||||
tensorTy, mmaLayout.getWarpsPerCTA()[1]);
|
||||
DotOpMmaV2ConversionHelper helper(mmaLayout);
|
||||
helper.deduceMmaType(tensorTy);
|
||||
matTy = helper.getMatType();
|
||||
} else if (mmaLayout.isVolta()) {
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
bool isRow = layout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
auto [isARow, isBRow, isAVec4, isBVec4, _0] =
|
||||
mmaLayout.decodeVoltaLayoutStates();
|
||||
if (layout.getOpIdx() == 0) {
|
||||
DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
|
||||
numElems =
|
||||
helper.numElemsPerThreadA(shape, isARow, isAVec4, aParam.vec);
|
||||
} else {
|
||||
DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
|
||||
numElems =
|
||||
helper.numElemsPerThreadB(shape, isBRow, isBVec4, bParam.vec);
|
||||
}
|
||||
matTy = helper.getMatType(tensorTy);
|
||||
}
|
||||
|
||||
auto numPackedElems = matTy.cast<LLVM::LLVMStructType>()
|
||||
.getBody()[0]
|
||||
.cast<VectorType>()
|
||||
.getNumElements();
|
||||
retTy = vec_ty(elemType, numPackedElems);
|
||||
retVal = undef(retTy);
|
||||
for (auto i = 0; i < numPackedElems; ++i) {
|
||||
retVal = insert_element(retTy, retVal, constVal, i32_val(i));
|
||||
}
|
||||
|
||||
} else if (auto blockedLayout = parent.dyn_cast<BlockedEncodingAttr>()) {
|
||||
numElems = DotOpFMAConversionHelper::getNumElemsPerThread(shape, layout);
|
||||
} else {
|
||||
assert(false && "Unsupported layout found");
|
||||
}
|
||||
|
||||
auto structTy = LLVM::LLVMStructType::getLiteral(
|
||||
rewriter.getContext(), SmallVector<Type>(numElems, retTy));
|
||||
return getStructFromElements(loc, SmallVector<Value>(numElems, retVal),
|
||||
rewriter, structTy);
|
||||
}
|
||||
|
||||
static Value convertSplatLikeOpWithMmaLayout(
|
||||
const MmaEncodingAttr &layout, Type resType, Type elemType,
|
||||
Value constVal, TypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto shape = tensorTy.getShape();
|
||||
if (layout.isAmpere()) {
|
||||
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy);
|
||||
size_t fcSize = 4 * repM * repN;
|
||||
|
||||
auto structTy = LLVM::LLVMStructType::getLiteral(
|
||||
rewriter.getContext(), SmallVector<Type>(fcSize, elemType));
|
||||
return getStructFromElements(loc, SmallVector<Value>(fcSize, constVal),
|
||||
rewriter, structTy);
|
||||
}
|
||||
if (layout.isVolta()) {
|
||||
DotOpMmaV1ConversionHelper helper(layout);
|
||||
int repM = helper.getRepM(shape[0]);
|
||||
int repN = helper.getRepN(shape[1]);
|
||||
// According to mma layout of v1, each thread process 8 elements.
|
||||
int elems = 8 * repM * repN;
|
||||
|
||||
auto structTy = LLVM::LLVMStructType::getLiteral(
|
||||
rewriter.getContext(), SmallVector<Type>(elems, elemType));
|
||||
return getStructFromElements(loc, SmallVector<Value>(elems, constVal),
|
||||
rewriter, structTy);
|
||||
}
|
||||
|
||||
assert(false && "Unsupported mma layout found");
|
||||
return {};
|
||||
return getStructFromElements(loc, elems, rewriter, structTy);
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
|
||||
@@ -254,6 +145,7 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||
SmallVector<Type> types(elems, elemTy);
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
|
||||
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
|
||||
|
||||
Value view = getStructFromElements(loc, vals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
return success();
|
||||
|
||||
@@ -378,11 +378,19 @@ unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
|
||||
int res = 0;
|
||||
if (isVolta()) {
|
||||
unsigned mmasRow = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]);
|
||||
unsigned mmasCol = ceil<unsigned>(shape[1], 16 * getWarpsPerCTA()[1]);
|
||||
// Each warp-level mma884 will perform a m16xn16xk4 mma, thus get a m16xn16
|
||||
// matrix as result.
|
||||
res = mmasRow * mmasCol * (16 * 16 / 32);
|
||||
auto [isARow, isBRow, isAVec4, isBVec4, id] = decodeVoltaLayoutStates();
|
||||
static constexpr std::array<unsigned, 2> fpw{{2, 2}};
|
||||
unsigned packSize0 = (isARow || isAVec4) ? 1 : 2;
|
||||
unsigned packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
||||
unsigned repM = 2 * packSize0;
|
||||
unsigned repN = 2 * packSize1;
|
||||
unsigned spwM = fpw[0] * 4 * repM;
|
||||
unsigned spwN = fpw[1] * 4 * repN;
|
||||
unsigned wptM = getWarpsPerCTA()[0];
|
||||
unsigned wptN = getWarpsPerCTA()[1];
|
||||
unsigned resM = repM * std::max<int>(1, shape[0] / (spwM * wptM));
|
||||
unsigned resN = 2 * repN * std::max<int>(1, shape[1] / (spwN * wptN));
|
||||
res = resM * resN;
|
||||
} else if (isAmpere()) {
|
||||
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
|
||||
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
|
||||
|
||||
@@ -601,7 +601,8 @@ public:
|
||||
return failure();
|
||||
}
|
||||
// don't rematerialize non-element-wise
|
||||
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
|
||||
if (!isa<triton::ViewOp, triton::CatOp>(op) &&
|
||||
!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
|
||||
!op->hasTrait<mlir::OpTrait::Elementwise>() &&
|
||||
!isa<triton::StoreOp>(op)) {
|
||||
return failure();
|
||||
@@ -865,8 +866,10 @@ public:
|
||||
return failure();
|
||||
|
||||
for (Operation *op : cvtSlices) {
|
||||
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
|
||||
!op->hasTrait<mlir::OpTrait::SameOperandsAndResultType>())
|
||||
if (!isa<triton::ViewOp, triton::CatOp>(op) &&
|
||||
!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
|
||||
!op->hasTrait<mlir::OpTrait::Elementwise>() &&
|
||||
!isa<triton::StoreOp>(op))
|
||||
return failure();
|
||||
for (Value arg : op->getOperands()) {
|
||||
Operation *argOp = arg.getDefiningOp();
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace {
|
||||
@@ -16,297 +18,155 @@ using triton::gpu::MmaEncodingAttr;
|
||||
using triton::gpu::SharedEncodingAttr;
|
||||
using triton::gpu::SliceEncodingAttr;
|
||||
|
||||
// This pattern collects the wrong Mma those need to update and create the right
|
||||
// ones for each.
|
||||
// TODO[Superjomn]: RewirtePattern is not needed here, Rewrite this to a method
|
||||
class CollectMmaToUpdateForVolta : public mlir::RewritePattern {
|
||||
// Holds the mapping from old(wrong) mmaEncodingAttr to the new(correct)
|
||||
// mmaEncodingAttr.
|
||||
DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
|
||||
// Get the wpt for MMAv1 using more information.
|
||||
// Reference the original logic here
|
||||
// https://github.com/openai/triton/blob/0e4691e6dd91e001a8d33b71badf8b3314325459/lib/codegen/analysis/layout.cc#L223
|
||||
SmallVector<unsigned> getWarpsPerCTA(ArrayRef<int64_t> shape, bool isARow,
|
||||
bool isBRow, bool isAVec4, bool isBVec4,
|
||||
int numWarps) {
|
||||
// TODO[Superjomn]: Share code with
|
||||
// DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the
|
||||
// rep,spw and fpw.
|
||||
SmallVector<unsigned> wpt({1, 1});
|
||||
SmallVector<unsigned> wpt_nm1;
|
||||
|
||||
public:
|
||||
CollectMmaToUpdateForVolta(
|
||||
mlir::MLIRContext *ctx,
|
||||
DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
|
||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 1, ctx),
|
||||
mmaToUpdate(mmaToUpdate) {}
|
||||
SmallVector<int, 2> rep(2), spw(2);
|
||||
std::array<int, 3> fpw{{2, 2, 1}};
|
||||
int packSize0 = (isARow || isAVec4) ? 1 : 2;
|
||||
rep[0] = 2 * packSize0;
|
||||
spw[0] = fpw[0] * 4 * rep[0];
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto dotOp = cast<triton::DotOp>(op);
|
||||
auto *ctx = dotOp->getContext();
|
||||
auto AT = dotOp.a().getType().cast<RankedTensorType>();
|
||||
auto BT = dotOp.b().getType().cast<RankedTensorType>();
|
||||
auto DT = dotOp.d().getType().cast<RankedTensorType>();
|
||||
auto shapeA = AT.getShape();
|
||||
auto shapeB = BT.getShape();
|
||||
if (!DT.getEncoding())
|
||||
return failure();
|
||||
auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
|
||||
if (!(mmaLayout && mmaLayout.isVolta()))
|
||||
return failure();
|
||||
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
||||
rep[1] = 2 * packSize1;
|
||||
spw[1] = fpw[1] * 4 * rep[1];
|
||||
|
||||
// Has processed.
|
||||
if (mmaToUpdate.count(mmaLayout))
|
||||
return failure();
|
||||
do {
|
||||
wpt_nm1 = wpt;
|
||||
if (wpt[0] * wpt[1] < numWarps)
|
||||
wpt[0] = std::clamp<int>(wpt[0] * 2, 1, shape[0] / spw[0]);
|
||||
if (wpt[0] * wpt[1] < numWarps)
|
||||
wpt[1] = std::clamp<int>(wpt[1] * 2, 1, shape[1] / spw[1]);
|
||||
} while (wpt_nm1 != wpt);
|
||||
|
||||
auto dotOperandA = AT.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
auto [isARow_, isBRow_, isAVec4_, isBVec4_, mmaId] =
|
||||
mmaLayout.decodeVoltaLayoutStates();
|
||||
return wpt;
|
||||
}
|
||||
|
||||
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
|
||||
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
|
||||
// Given a (potentially malformed) DotOp, determines the optimal
|
||||
// MMAEncoding to use on V100
|
||||
LogicalResult getOptimizedV100MMaLayout(triton::DotOp dotOp,
|
||||
MmaEncodingAttr &old,
|
||||
MmaEncodingAttr &ret) {
|
||||
auto *ctx = dotOp->getContext();
|
||||
auto AT = dotOp.a().getType().cast<RankedTensorType>();
|
||||
auto BT = dotOp.b().getType().cast<RankedTensorType>();
|
||||
auto DT = dotOp.d().getType().cast<RankedTensorType>();
|
||||
auto shapeA = AT.getShape();
|
||||
auto shapeB = BT.getShape();
|
||||
if (!DT.getEncoding())
|
||||
return mlir::failure();
|
||||
auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
|
||||
if (!(mmaLayout && mmaLayout.isVolta()))
|
||||
return mlir::failure();
|
||||
// We have an MmaEncodingAttr here. Find the correct layout for it.
|
||||
auto dotOperandA = AT.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
auto [isARow_, isBRow_, isAVec4_, isBVec4_, mmaId] =
|
||||
mmaLayout.decodeVoltaLayoutStates();
|
||||
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
|
||||
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
|
||||
// The wpt of MMAv1 is also determined by isARow, isBRow and shape, and it
|
||||
// could only be set here for those states might be updated by previous
|
||||
// patterns in the Combine Pass.
|
||||
auto tgtWpt = getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4, isBVec4,
|
||||
product(mmaLayout.getWarpsPerCTA()));
|
||||
if (isARow == isARow_ && isBRow == isBRow_ && isAVec4 == isAVec4_ &&
|
||||
isBVec4 == isBVec4_) {
|
||||
if (tgtWpt == mmaLayout.getWarpsPerCTA())
|
||||
return mlir::failure();
|
||||
}
|
||||
// Recalculate the wpt, for here we could get the latest information, the
|
||||
// wpt should be updated.
|
||||
auto updatedWpt =
|
||||
getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4, isBVec4,
|
||||
product(mmaLayout.getWarpsPerCTA()));
|
||||
// return results
|
||||
old = mmaLayout;
|
||||
ret =
|
||||
MmaEncodingAttr::get(ctx, mmaLayout.getVersionMajor(), updatedWpt,
|
||||
AT.getShape(), BT.getShape(), isARow, isBRow, mmaId);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// The wpt of MMAv1 is also determined by isARow, isBRow and shape, and it
|
||||
// could only be set here for those states might be updated by previous
|
||||
// patterns in the Combine Pass.
|
||||
auto tgtWpt = getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4,
|
||||
isBVec4, product(mmaLayout.getWarpsPerCTA()));
|
||||
if (isARow == isARow_ && isBRow == isBRow_ && isAVec4 == isAVec4_ &&
|
||||
isBVec4 == isBVec4_) {
|
||||
if (tgtWpt == mmaLayout.getWarpsPerCTA())
|
||||
return failure();
|
||||
// Replace result op type
|
||||
void setOpResultType(Operation *op, ArrayRef<Type> newTypes) {
|
||||
if (op->getNumResults() != newTypes.size())
|
||||
llvm_unreachable("number of types different from number of results");
|
||||
// nothing to do
|
||||
if (op->getNumResults() == 0)
|
||||
return;
|
||||
// replace types
|
||||
for (unsigned i = 0; i < op->getNumResults(); i++) {
|
||||
Type newType = newTypes[i];
|
||||
op->getResult(i).setType(newType);
|
||||
}
|
||||
// special case: arith.constant: we need to change the value attr
|
||||
if (isa<arith::ConstantOp>(op)) {
|
||||
Type newType = newTypes[0];
|
||||
auto attr = op->getAttrDictionary()
|
||||
.get("value")
|
||||
.dyn_cast<mlir::DenseElementsAttr>();
|
||||
if (attr) {
|
||||
auto newAttr = mlir::DenseElementsAttr::getFromRawBuffer(
|
||||
newType, attr.getRawData(), true);
|
||||
op->setAttr("value", newAttr);
|
||||
}
|
||||
|
||||
MmaEncodingAttr newMmaLayout;
|
||||
{
|
||||
// Recalculate the wpt, for here we could get the latest information, the
|
||||
// wpt should be updated.
|
||||
auto updatedWpt =
|
||||
getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4, isBVec4,
|
||||
product(mmaLayout.getWarpsPerCTA()));
|
||||
|
||||
newMmaLayout = MmaEncodingAttr::get(ctx, mmaLayout.getVersionMajor(),
|
||||
updatedWpt, AT.getShape(),
|
||||
BT.getShape(), isARow, isBRow, mmaId);
|
||||
}
|
||||
|
||||
// Collect the wrong MMA Layouts, and mark need to update.
|
||||
mmaToUpdate.try_emplace(mmaLayout, newMmaLayout);
|
||||
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
// Get the wpt for MMAv1 using more information.
|
||||
// Reference the original logic here
|
||||
// https://github.com/openai/triton/blob/0e4691e6dd91e001a8d33b71badf8b3314325459/lib/codegen/analysis/layout.cc#L223
|
||||
SmallVector<unsigned> getWarpsPerCTA(ArrayRef<int64_t> shape, bool isARow,
|
||||
bool isBRow, bool isAVec4, bool isBVec4,
|
||||
int numWarps) const {
|
||||
// TODO[Superjomn]: Share code with
|
||||
// DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the
|
||||
// rep,spw and fpw.
|
||||
SmallVector<unsigned> wpt({1, 1});
|
||||
SmallVector<unsigned> wpt_nm1;
|
||||
|
||||
SmallVector<int, 2> rep(2), spw(2);
|
||||
std::array<int, 3> fpw{{2, 2, 1}};
|
||||
int packSize0 = (isARow || isAVec4) ? 1 : 2;
|
||||
rep[0] = 2 * packSize0;
|
||||
spw[0] = fpw[0] * 4 * rep[0];
|
||||
|
||||
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
||||
rep[1] = 2 * packSize1;
|
||||
spw[1] = fpw[1] * 4 * rep[1];
|
||||
|
||||
do {
|
||||
wpt_nm1 = wpt;
|
||||
if (wpt[0] * wpt[1] < numWarps)
|
||||
wpt[0] = std::clamp<int>(wpt[0] * 2, 1, shape[0] / spw[0]);
|
||||
if (wpt[0] * wpt[1] < numWarps)
|
||||
wpt[1] = std::clamp<int>(wpt[1] * 2, 1, shape[1] / spw[1]);
|
||||
} while (wpt_nm1 != wpt);
|
||||
|
||||
return wpt;
|
||||
// update style type given the provided layoutMap
|
||||
Type updateStaleType(
|
||||
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &layoutMap,
|
||||
RankedTensorType type) {
|
||||
auto encoding = type.getEncoding();
|
||||
// mma encoding
|
||||
if (auto mma = encoding.dyn_cast<MmaEncodingAttr>()) {
|
||||
auto newMma = layoutMap.lookup(mma);
|
||||
if (!newMma)
|
||||
return Type();
|
||||
return RankedTensorType::get(type.getShape(), type.getElementType(),
|
||||
newMma);
|
||||
}
|
||||
};
|
||||
|
||||
class UpdateMMAForMMAv1 : public mlir::RewritePattern {
|
||||
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
|
||||
|
||||
public:
|
||||
UpdateMMAForMMAv1(
|
||||
MLIRContext *context,
|
||||
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
|
||||
: RewritePattern(MatchAnyOpTypeTag{}, 1, context),
|
||||
mmaToUpdate(mmaToUpdate) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// Nothing to update
|
||||
if (mmaToUpdate.empty())
|
||||
return failure();
|
||||
|
||||
if (auto dotOp = llvm::dyn_cast<DotOp>(op))
|
||||
return rewriteDotOp(op, rewriter);
|
||||
else if (auto cvtOp = llvm::dyn_cast<ConvertLayoutOp>(op))
|
||||
return rewriteCvtOp(op, rewriter);
|
||||
else if (auto expandDimsOp = llvm::dyn_cast<triton::ExpandDimsOp>(op))
|
||||
return rewriteExpandDimsOp(op, rewriter);
|
||||
else if (auto constOp = llvm::dyn_cast<arith::ConstantOp>(op))
|
||||
return rewriteConstantOp(op, rewriter);
|
||||
else
|
||||
return rewriteElementwiseOp(op, rewriter);
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult rewriteDotOp(Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const {
|
||||
auto dotOp = llvm::cast<DotOp>(op);
|
||||
auto tensorTy = dotOp->getResult(0).getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy)
|
||||
return failure();
|
||||
|
||||
auto mma = dotOp.d()
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.dyn_cast<MmaEncodingAttr>();
|
||||
if (!mma || !mmaToUpdate.count(mma))
|
||||
return failure();
|
||||
|
||||
auto newTensorTy = getUpdatedType(tensorTy);
|
||||
rewriter.replaceOpWithNewOp<DotOp>(op, newTensorTy, dotOp.a(), dotOp.b(),
|
||||
dotOp.c(), dotOp.allowTF32());
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult rewriteCvtOp(Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const {
|
||||
auto cvt = llvm::cast<ConvertLayoutOp>(op);
|
||||
if (!needUpdate(cvt.getResult().getType()))
|
||||
return failure();
|
||||
auto tensorTy = cvt.result().getType().dyn_cast<RankedTensorType>();
|
||||
|
||||
auto newTensorTy = getUpdatedType(tensorTy);
|
||||
auto newOp = rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
|
||||
cvt.getOperand());
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult rewriteExpandDimsOp(Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const {
|
||||
auto expandDims = llvm::cast<triton::ExpandDimsOp>(op);
|
||||
auto srcTy = expandDims.src().getType();
|
||||
auto resTy = expandDims.getResult().getType();
|
||||
|
||||
// the result type need to update
|
||||
if (!needUpdate(srcTy) && needUpdate(resTy)) {
|
||||
rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(op, expandDims.src(),
|
||||
expandDims.axis());
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult rewriteConstantOp(Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const {
|
||||
auto constant = llvm::cast<arith::ConstantOp>(op);
|
||||
auto resTy = constant.getResult().getType();
|
||||
if (!needUpdate(resTy))
|
||||
return failure();
|
||||
|
||||
auto tensorTy = constant.getResult().getType().cast<RankedTensorType>();
|
||||
auto mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>();
|
||||
auto dot = tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
|
||||
if (!mma && !dot)
|
||||
return failure();
|
||||
|
||||
auto newTensorTy = getUpdatedType(tensorTy);
|
||||
if (auto attr = constant.getValue().dyn_cast<SplatElementsAttr>()) {
|
||||
auto newRet =
|
||||
SplatElementsAttr::get(newTensorTy, attr.getSplatValue<Attribute>());
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult rewriteElementwiseOp(Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const {
|
||||
if (op->getNumOperands() != 1 || op->getNumResults() != 1)
|
||||
return failure();
|
||||
|
||||
auto srcTy = op->getOperand(0).getType();
|
||||
auto resTy = op->getResult(0).getType();
|
||||
|
||||
if (needUpdate(resTy)) {
|
||||
// The op-inputs' types are not necessary to update, for some
|
||||
// replaceOpWithNewOp will help update them.
|
||||
op->getResult(0).setType(
|
||||
getUpdatedType(resTy.dyn_cast<RankedTensorType>()));
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
RankedTensorType getUpdatedType(RankedTensorType type) const {
|
||||
if (!needUpdate(type))
|
||||
return type;
|
||||
auto encoding = type.getEncoding();
|
||||
if (auto mma = encoding.dyn_cast<MmaEncodingAttr>()) {
|
||||
auto newMma = mmaToUpdate.lookup(mma);
|
||||
// slice encoding
|
||||
else if (auto slice = encoding.dyn_cast<SliceEncodingAttr>()) {
|
||||
if (auto mma = slice.getParent().dyn_cast<MmaEncodingAttr>()) {
|
||||
auto newMma = layoutMap.lookup(mma);
|
||||
if (!newMma)
|
||||
return Type();
|
||||
auto newSlice =
|
||||
SliceEncodingAttr::get(slice.getContext(), slice.getDim(), newMma);
|
||||
return RankedTensorType::get(type.getShape(), type.getElementType(),
|
||||
newMma);
|
||||
} else if (auto slice = encoding.dyn_cast<SliceEncodingAttr>()) {
|
||||
if (auto mma = slice.getParent().dyn_cast<MmaEncodingAttr>()) {
|
||||
auto newMma = mmaToUpdate.lookup(mma);
|
||||
auto newSlice =
|
||||
SliceEncodingAttr::get(slice.getContext(), slice.getDim(), newMma);
|
||||
return RankedTensorType::get(type.getShape(), type.getElementType(),
|
||||
newSlice);
|
||||
}
|
||||
} else if (auto dotOp = encoding.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
if (auto mma = dotOp.getParent().dyn_cast<MmaEncodingAttr>()) {
|
||||
auto newMma = mmaToUpdate.lookup(mma);
|
||||
auto newDotOp =
|
||||
DotOperandEncodingAttr::get(dotOp.getContext(), dotOp.getOpIdx(),
|
||||
newMma, dotOp.getIsMMAv1Row());
|
||||
return RankedTensorType::get(type.getShape(), type.getElementType(),
|
||||
newDotOp);
|
||||
}
|
||||
newSlice);
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
// Tell if this type contains a wrong MMA encoding and need to update.
|
||||
bool needUpdate(Type type) const {
|
||||
auto tensorTy = type.dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy)
|
||||
return false;
|
||||
return needUpdate(tensorTy);
|
||||
}
|
||||
|
||||
// Tell if this type contains a wrong MMA encoding and need to update.
|
||||
bool needUpdate(RankedTensorType type) const {
|
||||
auto encoding = type.getEncoding();
|
||||
if (!encoding)
|
||||
return false;
|
||||
|
||||
MmaEncodingAttr mma;
|
||||
if ((mma = encoding.dyn_cast<MmaEncodingAttr>())) {
|
||||
} else if (auto slice = encoding.dyn_cast<SliceEncodingAttr>()) {
|
||||
mma = slice.getParent().dyn_cast<MmaEncodingAttr>();
|
||||
} else if (auto dotOp = encoding.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
mma = dotOp.getParent().dyn_cast<MmaEncodingAttr>();
|
||||
// dot operand encoding
|
||||
else if (auto dotOp = encoding.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
if (auto mma = dotOp.getParent().dyn_cast<MmaEncodingAttr>()) {
|
||||
auto newMma = layoutMap.lookup(mma);
|
||||
if (!newMma)
|
||||
return Type();
|
||||
auto newDotOp = DotOperandEncodingAttr::get(
|
||||
dotOp.getContext(), dotOp.getOpIdx(), newMma, dotOp.getIsMMAv1Row());
|
||||
return RankedTensorType::get(type.getShape(), type.getElementType(),
|
||||
newDotOp);
|
||||
}
|
||||
|
||||
return mma && mmaToUpdate.count(mma);
|
||||
}
|
||||
};
|
||||
return Type();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
class UpdateMmaForVoltaPass
|
||||
: public UpdateMmaForVoltaBase<UpdateMmaForVoltaPass> {
|
||||
public:
|
||||
@@ -314,34 +174,34 @@ public:
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp m = getOperation();
|
||||
|
||||
llvm::DenseMap<MmaEncodingAttr, MmaEncodingAttr> mmaToUpdate;
|
||||
{
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
patterns.add<CollectMmaToUpdateForVolta>(context, mmaToUpdate);
|
||||
|
||||
GreedyRewriteConfig config;
|
||||
config.enableRegionSimplification =
|
||||
false; // The pattern doesn't modify the IR
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns), config).failed())
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
if (!mmaToUpdate.empty()) {
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
patterns.add<UpdateMMAForMMAv1>(context, mmaToUpdate);
|
||||
|
||||
mlir::GreedyRewriteConfig config;
|
||||
// Make sure the slice and dot_operand layouts' parent mma are updated
|
||||
// before updating DotOp or it will get a mismatch parent-encoding.
|
||||
config.useTopDownTraversal = true;
|
||||
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns), config).failed())
|
||||
signalPassFailure();
|
||||
|
||||
if (fixupLoops(m).failed())
|
||||
signalPassFailure();
|
||||
}
|
||||
// Step 1:
|
||||
// Build a map from old MMA encoding to new MMA encoding.
|
||||
DenseMap<MmaEncodingAttr, MmaEncodingAttr> layoutMap;
|
||||
m.walk([&layoutMap](triton::DotOp dotOp) {
|
||||
MmaEncodingAttr newLayout;
|
||||
MmaEncodingAttr oldLayout;
|
||||
if (failed(getOptimizedV100MMaLayout(dotOp, oldLayout, newLayout)))
|
||||
return;
|
||||
layoutMap[oldLayout] = newLayout;
|
||||
});
|
||||
// Step 2:
|
||||
// Replace all wrong layouts with the right one
|
||||
m.walk([&layoutMap](Operation *op) {
|
||||
if (op->getNumResults() != 1)
|
||||
return;
|
||||
auto type = op->getResult(0).getType().dyn_cast<RankedTensorType>();
|
||||
if (!type)
|
||||
return;
|
||||
Type newType = updateStaleType(layoutMap, type);
|
||||
if (!newType)
|
||||
return;
|
||||
setOpResultType(op, {newType});
|
||||
});
|
||||
// Step 3:
|
||||
// We may have messed up some loops in the process.
|
||||
// Fix them up
|
||||
if (fixupLoops(m).failed())
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1240,23 +1240,24 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
|
||||
def test_dot_without_load(dtype_str):
|
||||
@triton.jit
|
||||
def _kernel(out):
|
||||
a = GENERATE_TEST_HERE
|
||||
b = GENERATE_TEST_HERE
|
||||
c = tl.dot(a, b)
|
||||
out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
|
||||
tl.store(out_ptr, c)
|
||||
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
|
||||
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
|
||||
# def test_dot_without_load(dtype_str):
|
||||
# @triton.jit
|
||||
# def _kernel(out):
|
||||
# a = GENERATE_TEST_HERE
|
||||
# b = GENERATE_TEST_HERE
|
||||
# c = tl.dot(a, b)
|
||||
# out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
|
||||
# tl.store(out_ptr, c)
|
||||
|
||||
kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
|
||||
a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
out_ref = torch.matmul(a, b)
|
||||
out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
kernel[(1,)](out)
|
||||
assert torch.all(out == out_ref)
|
||||
# kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
|
||||
# a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
# b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
# out_ref = torch.matmul(a, b)
|
||||
# out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
# kernel[(1,)](out)
|
||||
# assert torch.all(out == out_ref)
|
||||
|
||||
# ---------------
|
||||
# test arange
|
||||
|
||||
Reference in New Issue
Block a user