mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Use getOrder for mma layout warps order instead of the hardcoded col-major order (#1825)
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
#include "ConvertLayoutOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::LLVM::linearize;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getContigPerThread;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
@@ -143,9 +145,10 @@ private:
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
// TODO: fix the bug in MMAEncodingAttr document
|
||||
SmallVector<Value> multiDimWarpId(2);
|
||||
multiDimWarpId[0] = urem(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
|
||||
multiDimWarpId[1] = udiv(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
|
||||
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
auto order = triton::gpu::getOrder(mmaLayout);
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
Value _1 = i32_val(1);
|
||||
Value _2 = i32_val(2);
|
||||
Value _4 = i32_val(4);
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
using namespace mlir;
|
||||
|
||||
using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
@@ -552,14 +553,17 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
|
||||
int kWidth = encoding.getMMAv2kWidth();
|
||||
|
||||
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
auto order = triton::gpu::getOrder(mmaLayout);
|
||||
Value warp = udiv(thread, i32_val(32));
|
||||
Value lane = urem(thread, i32_val(32));
|
||||
// Note: warps are currently column major in MMA layout
|
||||
Value warpRowIndex = urem(warp, i32_val(warpsPerCTA[0]));
|
||||
Value warpColIndex =
|
||||
urem(udiv(warp, i32_val(warpsPerCTA[0])), i32_val(warpsPerCTA[1]));
|
||||
Value warpM = urem(warpRowIndex, i32_val(shape[0] / 16));
|
||||
Value warpN = urem(warpColIndex, i32_val(shape[1] / 8));
|
||||
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warp, warpsPerCTA, order);
|
||||
unsigned lastAxis = order[order.size() - 1];
|
||||
multiDimWarpId[lastAxis] =
|
||||
urem(multiDimWarpId[lastAxis], i32_val(warpsPerCTA[lastAxis]));
|
||||
Value warpM = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
|
||||
Value warpN = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
|
||||
|
||||
int warpsPerTile;
|
||||
if (isA)
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
#include "ReduceOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::linearize;
|
||||
using ::mlir::LLVM::shflSync;
|
||||
using ::mlir::LLVM::storeShared;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::SharedMemoryObject;
|
||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
@@ -504,65 +505,6 @@ public:
|
||||
return mask;
|
||||
}
|
||||
|
||||
// Convert an \param index to a multi-dim coordinate given \param shape and
|
||||
// \param order.
|
||||
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order) const {
|
||||
unsigned rank = shape.size();
|
||||
assert(rank == order.size());
|
||||
auto reordered = reorder(shape, order);
|
||||
auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
|
||||
SmallVector<Value> multiDim(rank);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
multiDim[order[i]] = reorderedMultiDim[i];
|
||||
}
|
||||
return multiDim;
|
||||
}
|
||||
|
||||
// Split \param linear into one coordinate per entry of \param shape. Every
// dimension except the last receives `carry % extent` and passes
// `carry / extent` on; the last dimension receives the remaining carry with
// no modulo applied.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                               Location loc, Value linear,
                               ArrayRef<unsigned> shape) const {
  const unsigned numDims = shape.size();
  assert(numDims > 0);
  SmallVector<Value> coords(numDims);
  Value carry = linear;
  // With a single dimension the loop body never runs, so coords[0] is simply
  // \param linear.
  for (unsigned d = 0; d + 1 < numDims; ++d) {
    Value extent = i32_val(shape[d]);
    coords[d] = urem(carry, extent);
    carry = udiv(carry, extent);
  }
  coords[numDims - 1] = carry;
  return coords;
}
|
||||
|
||||
// Flatten \param multiDim into one index: \param multiDim and \param shape are
// each passed through reorder with \param order and then handed to the
// order-free overload.
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
                ArrayRef<unsigned> order) const {
  auto permutedCoords = reorder<Value>(multiDim, order);
  auto permutedShape = reorder<unsigned>(shape, order);
  return linearize(rewriter, loc, permutedCoords, permutedShape);
}
|
||||
|
||||
// Flatten \param multiDim into one index, treating entry 0 as the
// fastest-varying coordinate: starting from the last coordinate, each earlier
// one is folded in as `flat * extent + coord`. An empty \param multiDim
// yields the constant 0.
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
  // The zero constant is created up front even when it is not the result.
  Value flat = i32_val(0);
  if (multiDim.empty())
    return flat;

  flat = multiDim.back();
  // Visit the remaining (coordinate, extent) pairs from back to front.
  auto leadingPairs = llvm::zip(multiDim.drop_back(), shape.drop_back());
  for (auto [coord, extent] : llvm::reverse(leadingPairs)) {
    Value extentVal = i32_val(extent);
    flat = add(mul(flat, extentVal), coord);
  }
  return flat;
}
|
||||
|
||||
Value dot(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
|
||||
assert(offsets.size() == strides.size());
|
||||
@@ -927,19 +869,23 @@ private:
|
||||
const MmaEncodingAttr &mmaLayout,
|
||||
RankedTensorType type) const {
|
||||
auto shape = type.getShape();
|
||||
auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
assert(_warpsPerCTA.size() == 2);
|
||||
SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
|
||||
i32_val(_warpsPerCTA[1])};
|
||||
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
assert(warpsPerCTA.size() == 2);
|
||||
auto order = triton::gpu::getOrder(mmaLayout);
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = i32_val(32);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
Value warpId0 = urem(urem(warpId, warpsPerCTA[0]), i32_val(shape[0] / 16));
|
||||
Value warpId1 = urem(urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]),
|
||||
i32_val(shape[1] / 8));
|
||||
Value offWarp0 = mul(warpId0, i32_val(16));
|
||||
Value offWarp1 = mul(warpId1, i32_val(8));
|
||||
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
unsigned lastAxis = order[order.size() - 1];
|
||||
multiDimWarpId[lastAxis] =
|
||||
urem(multiDimWarpId[lastAxis], i32_val(warpsPerCTA[lastAxis]));
|
||||
multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
|
||||
multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
|
||||
Value offWarp0 = mul(multiDimWarpId[0], i32_val(16));
|
||||
Value offWarp1 = mul(multiDimWarpId[1], i32_val(8));
|
||||
|
||||
SmallVector<Value> multiDimBase(2);
|
||||
multiDimBase[0] = add(udiv(laneId, i32_val(4)), offWarp0);
|
||||
|
||||
@@ -70,6 +70,65 @@ getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
|
||||
return strides;
|
||||
}
|
||||
|
||||
// Convert an \param index to a multi-dim coordinate given \param shape and
|
||||
// \param order.
|
||||
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order) {
|
||||
unsigned rank = shape.size();
|
||||
assert(rank == order.size());
|
||||
auto reordered = reorder(shape, order);
|
||||
auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
|
||||
SmallVector<Value> multiDim(rank);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
multiDim[order[i]] = reorderedMultiDim[i];
|
||||
}
|
||||
return multiDim;
|
||||
}
|
||||
|
||||
// Split \param linear into one coordinate per entry of \param shape. Every
// dimension except the last receives `carry % extent` and passes
// `carry / extent` on; the last dimension receives the remaining carry with
// no modulo applied.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                               Location loc, Value linear,
                               ArrayRef<unsigned> shape) {
  const unsigned numDims = shape.size();
  assert(numDims > 0);
  SmallVector<Value> coords(numDims);
  Value carry = linear;
  // With a single dimension the loop body never runs, so coords[0] is simply
  // \param linear.
  for (unsigned d = 0; d + 1 < numDims; ++d) {
    Value extent = i32_val(shape[d]);
    coords[d] = urem(carry, extent);
    carry = udiv(carry, extent);
  }
  coords[numDims - 1] = carry;
  return coords;
}
|
||||
|
||||
// Flatten \param multiDim into one index: \param multiDim and \param shape are
// each passed through reorder with \param order and then handed to the
// order-free overload.
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
                ArrayRef<unsigned> order) {
  auto permutedCoords = reorder<Value>(multiDim, order);
  auto permutedShape = reorder<unsigned>(shape, order);
  return linearize(rewriter, loc, permutedCoords, permutedShape);
}
|
||||
|
||||
// Flatten \param multiDim into one index, treating entry 0 as the
// fastest-varying coordinate: starting from the last coordinate, each earlier
// one is folded in as `flat * extent + coord`. An empty \param multiDim
// yields the constant 0.
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) {
  // The zero constant is created up front even when it is not the result.
  Value flat = i32_val(0);
  if (multiDim.empty())
    return flat;

  flat = multiDim.back();
  // Visit the remaining (coordinate, extent) pairs from back to front.
  auto leadingPairs = llvm::zip(multiDim.drop_back(), shape.drop_back());
  for (auto [coord, extent] : llvm::reverse(leadingPairs)) {
    Value extentVal = i32_val(extent);
    flat = add(mul(flat, extentVal), coord);
  }
  return flat;
}
|
||||
|
||||
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
||||
Value val, Value pred) {
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
|
||||
@@ -258,6 +258,24 @@ SharedMemoryObject
|
||||
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
// Convert the flat index \param linear to a multi-dim coordinate given
|
||||
// \param shape and \param order.
|
||||
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order);
|
||||
|
||||
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape);
|
||||
|
||||
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order);
|
||||
|
||||
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape);
|
||||
|
||||
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
||||
Value val, Value pred);
|
||||
|
||||
|
||||
@@ -1497,7 +1497,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device):
|
||||
layouts = [
|
||||
BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0]),
|
||||
BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[2, 2])
|
||||
]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user