[BACKEND] Use getOrder for mma layout warps order instead of the hardcoded col-major order (#1825)

This commit is contained in:
Zahi Moudallal
2023-06-27 10:56:09 -07:00
committed by GitHub
parent d4c941177e
commit 2dcbf4783e
7 changed files with 112 additions and 78 deletions

View File

@@ -1,8 +1,10 @@
#include "ConvertLayoutOpToLLVM.h"
#include "Utility.h"
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
@@ -143,9 +145,10 @@ private:
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// TODO: fix the bug in MMAEncodingAttr document
SmallVector<Value> multiDimWarpId(2);
multiDimWarpId[0] = urem(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[1] = udiv(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
auto order = triton::gpu::getOrder(mmaLayout);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
Value _1 = i32_val(1);
Value _2 = i32_val(2);
Value _4 = i32_val(4);

View File

@@ -4,6 +4,7 @@
using namespace mlir;
using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
@@ -552,14 +553,17 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
int kWidth = encoding.getMMAv2kWidth();
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
auto order = triton::gpu::getOrder(mmaLayout);
Value warp = udiv(thread, i32_val(32));
Value lane = urem(thread, i32_val(32));
// Note: warps are currently column major in MMA layout
Value warpRowIndex = urem(warp, i32_val(warpsPerCTA[0]));
Value warpColIndex =
urem(udiv(warp, i32_val(warpsPerCTA[0])), i32_val(warpsPerCTA[1]));
Value warpM = urem(warpRowIndex, i32_val(shape[0] / 16));
Value warpN = urem(warpColIndex, i32_val(shape[1] / 8));
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warp, warpsPerCTA, order);
unsigned lastAxis = order[order.size() - 1];
multiDimWarpId[lastAxis] =
urem(multiDimWarpId[lastAxis], i32_val(warpsPerCTA[lastAxis]));
Value warpM = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
Value warpN = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
int warpsPerTile;
if (isA)

View File

@@ -1,8 +1,11 @@
#include "ReduceOpToLLVM.h"
#include "Utility.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::linearize;
using ::mlir::LLVM::shflSync;
using ::mlir::LLVM::storeShared;
using ::mlir::triton::gpu::getOrder;

View File

@@ -15,6 +15,7 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::SharedMemoryObject;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
@@ -504,65 +505,6 @@ public:
return mask;
}
// Split the linear index \param linear into one coordinate per dimension of
// \param shape, visiting the dimensions in the sequence given by \param order.
// The i-th coordinate of the order-permuted delinearization is written to
// slot order[i] of the result, so the returned vector is indexed by the
// original dimension numbers.
// NOTE(review): `reorder(shape, order)` is presumed to yield shape[order[i]]
// at position i, which would make order[0] the fastest-varying dimension;
// confirm against the definition of `reorder`.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                               Location loc, Value linear,
                               ArrayRef<unsigned> shape,
                               ArrayRef<unsigned> order) const {
  const unsigned numDims = shape.size();
  assert(numDims == order.size());
  // Delinearize against the permuted extents first ...
  auto permutedShape = reorder(shape, order);
  auto permutedCoords = delinearize(rewriter, loc, linear, permutedShape);
  // ... then scatter every coordinate back to the dimension it belongs to.
  SmallVector<Value> coords(numDims);
  for (unsigned pos = 0; pos != numDims; ++pos)
    coords[order[pos]] = permutedCoords[pos];
  return coords;
}
// Split the linear index \param linear into one coordinate per dimension of
// \param shape. Dimension 0 is peeled off first (remainder by shape[0]), then
// dimension 1 from the quotient, and so on; the last dimension receives
// whatever quotient is left and is not reduced modulo its extent.
// Requires a non-empty shape. For a rank-1 shape the result is `linear`
// itself and no arithmetic is emitted.
// NOTE(review): i32_val/urem/udiv appear to pick up `rewriter` and `loc`
// implicitly, which is why those parameter names must stay as they are;
// confirm against the helper definitions.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                               Location loc, Value linear,
                               ArrayRef<unsigned> shape) const {
  const unsigned numDims = shape.size();
  assert(numDims > 0);
  SmallVector<Value> coords(numDims);
  Value rest = linear;
  // Every dimension except the last contributes one remainder/quotient pair.
  for (unsigned dim = 0; dim + 1 < numDims; ++dim) {
    Value extent = i32_val(shape[dim]);
    coords[dim] = urem(rest, extent);
    rest = udiv(rest, extent);
  }
  coords[numDims - 1] = rest;
  return coords;
}
// Fold the coordinates in \param multiDim into a single linear index, with
// the dimensions of \param shape taken in the sequence given by \param order.
// Both the coordinates and the extents are permuted with `reorder` and then
// handed to the order-free overload.
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
                ArrayRef<unsigned> order) const {
  auto permutedCoords = reorder<Value>(multiDim, order);
  auto permutedShape = reorder<unsigned>(shape, order);
  return linearize(rewriter, loc, permutedCoords, permutedShape);
}
// Fold the coordinates in \param multiDim into a single linear index using
// the extents in \param shape. The last coordinate is the starting value;
// walking the remaining dimensions from back to front, each step computes
// `linear * extent + coordinate`, so dimension 0 ends up varying fastest.
// An empty \param multiDim yields the constant 0.
// NOTE(review): i32_val/add/mul appear to pick up `rewriter` and `loc`
// implicitly; confirm against the helper definitions.
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
  // The zero constant is created up front, exactly as before, even when it
  // ends up unused.
  Value result = i32_val(0);
  if (multiDim.size() == 0)
    return result;

  result = multiDim.back();
  auto leadingCoords = multiDim.drop_back();
  auto leadingExtents = shape.drop_back();
  for (auto [coord, extentValue] :
       llvm::reverse(llvm::zip(leadingCoords, leadingExtents))) {
    Value extent = i32_val(extentValue);
    Value scaled = mul(result, extent);
    result = add(scaled, coord);
  }
  return result;
}
Value dot(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
assert(offsets.size() == strides.size());
@@ -927,19 +869,23 @@ private:
const MmaEncodingAttr &mmaLayout,
RankedTensorType type) const {
auto shape = type.getShape();
auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
assert(_warpsPerCTA.size() == 2);
SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
i32_val(_warpsPerCTA[1])};
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
assert(warpsPerCTA.size() == 2);
auto order = triton::gpu::getOrder(mmaLayout);
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
Value warpId0 = urem(urem(warpId, warpsPerCTA[0]), i32_val(shape[0] / 16));
Value warpId1 = urem(urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]),
i32_val(shape[1] / 8));
Value offWarp0 = mul(warpId0, i32_val(16));
Value offWarp1 = mul(warpId1, i32_val(8));
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
unsigned lastAxis = order[order.size() - 1];
multiDimWarpId[lastAxis] =
urem(multiDimWarpId[lastAxis], i32_val(warpsPerCTA[lastAxis]));
multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
Value offWarp0 = mul(multiDimWarpId[0], i32_val(16));
Value offWarp1 = mul(multiDimWarpId[1], i32_val(8));
SmallVector<Value> multiDimBase(2);
multiDimBase[0] = add(udiv(laneId, i32_val(4)), offWarp0);

View File

@@ -70,6 +70,65 @@ getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
return strides;
}
// Split the linear index \param linear into one coordinate per dimension of
// \param shape, visiting the dimensions in the sequence given by \param order.
// The i-th coordinate of the order-permuted delinearization is written to
// slot order[i] of the result, so the returned vector is indexed by the
// original dimension numbers.
// NOTE(review): `reorder(shape, order)` is presumed to yield shape[order[i]]
// at position i, which would make order[0] the fastest-varying dimension;
// confirm against the definition of `reorder`.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                               Location loc, Value linear,
                               ArrayRef<unsigned> shape,
                               ArrayRef<unsigned> order) {
  const unsigned numDims = shape.size();
  assert(numDims == order.size());
  // Delinearize against the permuted extents first ...
  auto permutedShape = reorder(shape, order);
  auto permutedCoords = delinearize(rewriter, loc, linear, permutedShape);
  // ... then scatter every coordinate back to the dimension it belongs to.
  SmallVector<Value> coords(numDims);
  for (unsigned pos = 0; pos != numDims; ++pos)
    coords[order[pos]] = permutedCoords[pos];
  return coords;
}
// Split the linear index \param linear into one coordinate per dimension of
// \param shape. Dimension 0 is peeled off first (remainder by shape[0]), then
// dimension 1 from the quotient, and so on; the last dimension receives
// whatever quotient is left and is not reduced modulo its extent.
// Requires a non-empty shape. For a rank-1 shape the result is `linear`
// itself and no arithmetic is emitted.
// NOTE(review): i32_val/urem/udiv appear to pick up `rewriter` and `loc`
// implicitly, which is why those parameter names must stay as they are;
// confirm against the helper definitions.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                               Location loc, Value linear,
                               ArrayRef<unsigned> shape) {
  const unsigned numDims = shape.size();
  assert(numDims > 0);
  SmallVector<Value> coords(numDims);
  Value rest = linear;
  // Every dimension except the last contributes one remainder/quotient pair.
  for (unsigned dim = 0; dim + 1 < numDims; ++dim) {
    Value extent = i32_val(shape[dim]);
    coords[dim] = urem(rest, extent);
    rest = udiv(rest, extent);
  }
  coords[numDims - 1] = rest;
  return coords;
}
// Fold the coordinates in \param multiDim into a single linear index, with
// the dimensions of \param shape taken in the sequence given by \param order.
// Both the coordinates and the extents are permuted with `reorder` and then
// handed to the order-free overload.
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
                ArrayRef<unsigned> order) {
  auto permutedCoords = reorder<Value>(multiDim, order);
  auto permutedShape = reorder<unsigned>(shape, order);
  return linearize(rewriter, loc, permutedCoords, permutedShape);
}
// Fold the coordinates in \param multiDim into a single linear index using
// the extents in \param shape. The last coordinate is the starting value;
// walking the remaining dimensions from back to front, each step computes
// `linear * extent + coordinate`, so dimension 0 ends up varying fastest.
// An empty \param multiDim yields the constant 0.
// NOTE(review): i32_val/add/mul appear to pick up `rewriter` and `loc`
// implicitly; confirm against the helper definitions.
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) {
  // The zero constant is created up front, exactly as before, even when it
  // ends up unused.
  Value result = i32_val(0);
  if (multiDim.size() == 0)
    return result;

  result = multiDim.back();
  auto leadingCoords = multiDim.drop_back();
  auto leadingExtents = shape.drop_back();
  for (auto [coord, extentValue] :
       llvm::reverse(llvm::zip(leadingCoords, leadingExtents))) {
    Value extent = i32_val(extentValue);
    Value scaled = mul(result, extent);
    result = add(scaled, coord);
  }
  return result;
}
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
Value val, Value pred) {
MLIRContext *ctx = rewriter.getContext();

View File

@@ -258,6 +258,24 @@ SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter);
// Convert the linear index \param linear to a multi-dim coordinate given
// \param shape and \param order. The result is indexed by original dimension
// number: the i-th coordinate of the order-permuted delinearization is stored
// at slot order[i]. \param shape and \param order must have the same rank.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
Location loc, Value linear,
ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);
// Same conversion without a permutation: dimension 0 is peeled off first
// (it varies fastest) and the last dimension receives the remaining
// quotient. \param shape must not be empty.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
Location loc, Value linear,
ArrayRef<unsigned> shape);
// Combine the coordinates in \param multiDim into one linear index, with
// coordinates and extents both permuted by \param order before folding.
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);
// Combine the coordinates in \param multiDim into one linear index using the
// extents in \param shape, with dimension 0 varying fastest. An empty
// \param multiDim yields the constant 0.
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape);
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
Value val, Value pred);

View File

@@ -1497,7 +1497,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device):
layouts = [
BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0]),
BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1]),
MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
MmaLayout(version=(2, 0), warps_per_cta=[2, 2])
]