Merge branch 'triton-mlir' into ifu230601

This commit is contained in:
jayfurmanek
2023-06-01 16:18:25 -05:00
committed by GitHub
2 changed files with 13 additions and 12 deletions

View File

@@ -166,18 +166,19 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
// TODO make macro tile size granularity configurable
int macroTileM =
std::max<int>(shape[0] / (mfmaLayout.getWarpsPerCTA()[0] * 32), 1);
int wptM = std::min<int>(mfmaLayout.getWarpsPerCTA()[0], macroTileM);
std::max<int>(shape[0] / (warpsPerCTA[0] * 32), 1);
int wptM = std::min<int>(warpsPerCTA[0], macroTileM);
int macroTileN =
std::max<int>(shape[1] / (mfmaLayout.getWarpsPerCTA()[1] * 32), 1);
int wptN = std::min<int>(mfmaLayout.getWarpsPerCTA()[1], macroTileN);
std::max<int>(shape[1] / (warpsPerCTA[1] * 32), 1);
int wptN = std::min<int>(warpsPerCTA[1], macroTileN);
int wpt = std::max<int>(wptM, wptN);
SmallVector<Value> offsets;
if (isTransposed(order)) {
SmallVector<int64_t> elemsPerThread{aElemsPerThread[1], aElemsPerThread[0]};
SmallVector<int64_t> reps{numReps[1], numReps[0]};
int warpsPerGroupM = mfmaLayout.getWarpsPerCTA()[0];
unsigned int maxNumWarps = shape[0] / mfmaInstrM;
int warpsPerGroupM = std::min(warpsPerCTA[0], maxNumWarps);
offsets =
computeOffsetsTy2(rewriter, loc, elemsPerThread, warpsPerGroupM, waveM,
lane, wpt, numOfElems, reps, cSwizzleOffset);
@@ -242,17 +243,17 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
Value wave = udiv(thread, waveSize);
Value lane = urem(thread, waveSize);
Value waveN = getWaveN(rewriter, loc, wave, mfmaLayout.getWarpsPerCTA(),
Value waveN = getWaveN(rewriter, loc, wave, warpsPerCTA,
mfmaInstrN, shape[1]);
int numOfElems = std::max<int>(mfmaInstrK * mfmaInstrN / 64 /*wave size*/, 1);
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
int macroTileM =
std::max<int>(shape[0] / (mfmaLayout.getWarpsPerCTA()[0] * 32), 1);
int wptM = std::min<int>(mfmaLayout.getWarpsPerCTA()[0], macroTileM);
std::max<int>(shape[0] / (warpsPerCTA[0] * 32), 1);
int wptM = std::min<int>(warpsPerCTA[0], macroTileM);
int macroTileN =
std::max<int>(shape[1] / (mfmaLayout.getWarpsPerCTA()[1] * 32), 1);
int wptN = std::min<int>(mfmaLayout.getWarpsPerCTA()[1], macroTileN);
std::max<int>(shape[1] / (warpsPerCTA[1] * 32), 1);
int wptN = std::min<int>(warpsPerCTA[1], macroTileN);
int wpt = std::max<int>(wptM, wptN);
llvm::SmallVector<Value> offsets;

View File

@@ -1227,8 +1227,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
[32, 256, 32, 8],
]
for allow_tf32 in [False, True]
for col_a in [False]
for col_b in [False]
for col_a in [True,False]
for col_b in [True,False]
for dtype in ['int8', 'float16', 'float32']])
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, device='cuda'):
capability = torch.cuda.get_device_capability()