mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Use getOrder for mma layout warps order instead of the hardcoded col-major order (#1825)
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
#include "ConvertLayoutOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::LLVM::linearize;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getContigPerThread;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
@@ -143,9 +145,10 @@ private:
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
// TODO: fix the bug in MMAEncodingAttr document
|
||||
SmallVector<Value> multiDimWarpId(2);
|
||||
multiDimWarpId[0] = urem(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
|
||||
multiDimWarpId[1] = udiv(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
|
||||
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
auto order = triton::gpu::getOrder(mmaLayout);
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
Value _1 = i32_val(1);
|
||||
Value _2 = i32_val(2);
|
||||
Value _4 = i32_val(4);
|
||||
|
||||
Reference in New Issue
Block a user