mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge branch 'triton-mlir' into ifu230601
This commit is contained in:
@@ -166,18 +166,19 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
||||
// TODO make macro tile size granilarity configurable
|
||||
int macroTileM =
|
||||
std::max<int>(shape[0] / (mfmaLayout.getWarpsPerCTA()[0] * 32), 1);
|
||||
int wptM = std::min<int>(mfmaLayout.getWarpsPerCTA()[0], macroTileM);
|
||||
std::max<int>(shape[0] / (warpsPerCTA[0] * 32), 1);
|
||||
int wptM = std::min<int>(warpsPerCTA[0], macroTileM);
|
||||
int macroTileN =
|
||||
std::max<int>(shape[1] / (mfmaLayout.getWarpsPerCTA()[1] * 32), 1);
|
||||
int wptN = std::min<int>(mfmaLayout.getWarpsPerCTA()[1], macroTileN);
|
||||
std::max<int>(shape[1] / (warpsPerCTA[1] * 32), 1);
|
||||
int wptN = std::min<int>(warpsPerCTA[1], macroTileN);
|
||||
int wpt = std::max<int>(wptM, wptN);
|
||||
|
||||
SmallVector<Value> offsets;
|
||||
if (isTransposed(order)) {
|
||||
SmallVector<int64_t> elemsPerThread{aElemsPerThread[1], aElemsPerThread[0]};
|
||||
SmallVector<int64_t> reps{numReps[1], numReps[0]};
|
||||
int warpsPerGroupM = mfmaLayout.getWarpsPerCTA()[0];
|
||||
unsigned int maxNumWarps = shape[0] / mfmaInstrM;
|
||||
int warpsPerGroupM = std::min(warpsPerCTA[0], maxNumWarps);
|
||||
offsets =
|
||||
computeOffsetsTy2(rewriter, loc, elemsPerThread, warpsPerGroupM, waveM,
|
||||
lane, wpt, numOfElems, reps, cSwizzleOffset);
|
||||
@@ -242,17 +243,17 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
Value wave = udiv(thread, waveSize);
|
||||
Value lane = urem(thread, waveSize);
|
||||
|
||||
Value waveN = getWaveN(rewriter, loc, wave, mfmaLayout.getWarpsPerCTA(),
|
||||
Value waveN = getWaveN(rewriter, loc, wave, warpsPerCTA,
|
||||
mfmaInstrN, shape[1]);
|
||||
int numOfElems = std::max<int>(mfmaInstrK * mfmaInstrN / 64 /*wave size*/, 1);
|
||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
||||
|
||||
int macroTileM =
|
||||
std::max<int>(shape[0] / (mfmaLayout.getWarpsPerCTA()[0] * 32), 1);
|
||||
int wptM = std::min<int>(mfmaLayout.getWarpsPerCTA()[0], macroTileM);
|
||||
std::max<int>(shape[0] / (warpsPerCTA[0] * 32), 1);
|
||||
int wptM = std::min<int>(warpsPerCTA[0], macroTileM);
|
||||
int macroTileN =
|
||||
std::max<int>(shape[1] / (mfmaLayout.getWarpsPerCTA()[1] * 32), 1);
|
||||
int wptN = std::min<int>(mfmaLayout.getWarpsPerCTA()[1], macroTileN);
|
||||
std::max<int>(shape[1] / (warpsPerCTA[1] * 32), 1);
|
||||
int wptN = std::min<int>(warpsPerCTA[1], macroTileN);
|
||||
int wpt = std::max<int>(wptM, wptN);
|
||||
|
||||
llvm::SmallVector<Value> offsets;
|
||||
|
||||
@@ -1227,8 +1227,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
[32, 256, 32, 8],
|
||||
]
|
||||
for allow_tf32 in [False, True]
|
||||
for col_a in [False]
|
||||
for col_b in [False]
|
||||
for col_a in [True,False]
|
||||
for col_b in [True,False]
|
||||
for dtype in ['int8', 'float16', 'float32']])
|
||||
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, device='cuda'):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
|
||||
Reference in New Issue
Block a user