[MFMA] MI200 bfloat16 support (#294)

This PR enables bfloat16 support in MFMA dot on MI200, using the mfma_f32_32x32x8bf16_1k instruction.
Author: Alexander Efimov
Date: 2023-08-18 14:28:18 +02:00
Committed by: GitHub
Parent: f7cf2c032b
Commit: 23979098c8
4 changed files with 30 additions and 7 deletions
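
For context (not part of this diff), a minimal sketch of the kind of kernel this change enables on MI200: a bfloat16 x bfloat16 -> float32 tl.dot, which the MFMA path now lowers to mfma_f32_32x32x8bf16_1k. Kernel name, shapes, and launch parameters below are illustrative, not taken from the PR.

import torch
import triton
import triton.language as tl

@triton.jit
def bf16_dot_kernel(a_ptr, b_ptr, c_ptr,
                    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    # Load bf16 tiles of A (M x K, row-major) and B (K x N, row-major).
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # bf16 x bf16 accumulated in fp32; on MI200 this hits the MFMA dot path.
    c = tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

a = torch.randn((32, 64), device='cuda').to(torch.bfloat16)
b = torch.randn((64, 32), device='cuda').to(torch.bfloat16)
c = torch.empty((32, 32), device='cuda', dtype=torch.float32)
bf16_dot_kernel[(1,)](a, b, c, M=32, N=32, K=64)
torch.testing.assert_close(c, a.float() @ b.float(), rtol=1e-2, atol=1e-2)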

View File

@@ -459,6 +459,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
Type smemPtrTy = getShemPtrTy(aElemTy);
+Type resElemTy = aElemTy.isBF16() ? i16_ty : aElemTy;
int loadsPerThread = offsets.size() / (numRepM * numRepK);
const int elemsPerLoad = numOfElems / loadsPerThread;
@@ -466,7 +467,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
for (int m = 0; m < numRepM; ++m) {
for (int k = 0; k < numRepK; ++k) {
-auto vecTy = vec_ty(aElemTy, numOfElems);
+auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(aElemTy, elemsPerLoad);
@@ -479,11 +480,13 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
Value elemVal =
extract_element(aElemTy, vectorValue, i32_val(elemId));
+elemVal = bitcast(elemVal, resElemTy);
valVec = insert_element(vecTy, valVec, elemVal,
i32_val(loadId * elemsPerLoad + elemId));
}
} else {
valVec = extract_element(aElemTy, vectorValue, i32_val(0));
+valVec = bitcast(valVec, resElemTy);
}
}
if (aElemTy == i8_ty)
@@ -497,6 +500,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
numReps, smemObj, sharedLayout);
Value smemBase = computeBasePtr(rewriter, loc, smemObj);
+Type resElemTy = aElemTy.isBF16() ? i16_ty : aElemTy;
Type smemPtrTy = getShemPtrTy(aElemTy);
@@ -505,7 +509,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
for (int m = 0; m < numRepM; ++m) {
for (int k = 0; k < numRepK; ++k) {
-auto vecTy = vec_ty(aElemTy, numOfElems);
+auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(aElemTy, elemsPerLoad);
@@ -518,11 +522,13 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
Value elemVal =
extract_element(aElemTy, vectorValue, i32_val(elemId));
+elemVal = bitcast(elemVal, resElemTy);
valVec = insert_element(vecTy, valVec, elemVal,
i32_val(loadId * elemsPerLoad + elemId));
}
} else {
valVec = extract_element(aElemTy, vectorValue, i32_val(0));
+valVec = bitcast(valVec, resElemTy);
}
}
if (aElemTy == i8_ty)
@@ -602,6 +608,8 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
+Type resElemTy = bElemTy.isBF16() ? i16_ty : bElemTy;
Type smemPtrTy = getShemPtrTy(bElemTy);
const int loadsPerThread = offsets.size() / (numRepN * numRepK);
@@ -610,7 +618,7 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
for (int n = 0; n < numRepN; ++n) {
for (int k = 0; k < numRepK; ++k) {
-auto vecTy = vec_ty(bElemTy, numOfElems);
+auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(bElemTy, elemsPerLoad);
@@ -623,11 +631,13 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
Value elemVal =
extract_element(bElemTy, vectorValue, i32_val(elemId));
+elemVal = bitcast(elemVal, resElemTy);
valVec = insert_element(vecTy, valVec, elemVal,
i32_val(loadId * elemsPerLoad + elemId));
}
} else {
valVec = extract_element(bElemTy, vectorValue, i32_val(0));
+valVec = bitcast(valVec, resElemTy);
}
}
if (bElemTy == i8_ty)
@@ -642,13 +652,15 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
Value smemBase = computeBasePtr(rewriter, loc, smemObj);
+Type resElemTy = bElemTy.isBF16() ? i16_ty : bElemTy;
Type smemPtrTy = getShemPtrTy(bElemTy);
int loadsPerThread = offsets.size() / (numReps[0] * numReps[1]);
int elemsPerLoad = numOfElems / loadsPerThread;
for (int n = 0; n < numRepN; ++n) {
for (int k = 0; k < numRepK; ++k) {
-auto vecTy = vec_ty(bElemTy, numOfElems);
+auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(bElemTy, elemsPerLoad);
@@ -661,11 +673,13 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
Value elemVal =
extract_element(bElemTy, vectorValue, i32_val(elemId));
+elemVal = bitcast(elemVal, resElemTy);
valVec = insert_element(vecTy, valVec, elemVal,
i32_val(loadId * elemsPerLoad + elemId));
}
} else {
valVec = extract_element(bElemTy, vectorValue, i32_val(0));
+valVec = bitcast(valVec, resElemTy);
}
}
if (bElemTy == i8_ty)

View File

@@ -57,7 +57,7 @@ struct DotOpMFMAConversionHelper {
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32:
-return rewriter.create<ROCDL::mfma_f32_32x32x4bf16>(
+return rewriter.create<ROCDL::mfma_f32_32x32x8bf16_1k>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP32_FP32_FP32:

View File

@@ -669,7 +669,7 @@ static SmallVector<int64_t> getMFMAInstrShape(Type abElemType) {
if (abElemType.isF32())
return {32l, 32l, 2l}; // FP32_FP32_FP32_FP32;
if (abElemType.isBF16())
-return {32l, 32l, 4l}; // FP32_BF16_BF16_FP32;
+return {32l, 32l, 8l}; // FP32_BF16_BF16_FP32;
if (abElemType.isInteger(8))
return {32l, 32l, 8l}; // INT32_INT8_INT8_INT32;
if (abElemType.isF64())

View File

@@ -1233,6 +1233,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax']
for allow_tf32 in [True, False]
for in_dtype, out_dtype in [('float16', 'float16'),
+('bfloat16', 'float32'),
('float16', 'float32'),
('float32', 'float32')]
if not (allow_tf32 and (in_dtype in ['float16']))] +
@@ -1264,7 +1265,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
for allow_tf32 in [False, True]
for col_a in [True, False]
for col_b in [True, False]
-for in_dtype in ['int8', 'float16', 'float32']
+for in_dtype in ['int8', 'bfloat16', 'float16', 'float32']
for out_dtype in [None]])
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, device='cuda'):
capability = torch.cuda.get_device_capability()
@@ -1354,6 +1355,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
x_tri = to_triton(x, device=device)
y_tri = to_triton(y, device=device)
w_tri = to_triton(w, device=device)
+if in_dtype == 'bfloat16':
+    x_tri = x_tri.to(torch.bfloat16)
+    y_tri = y_tri.to(torch.bfloat16)
+    w_tri = w_tri.to(torch.bfloat16)
# triton result
if out_dtype == 'int8':
z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
@@ -1414,6 +1419,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
elif out_dtype == tl.float16:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
+elif in_dtype == 'bfloat16':
+    # added atol to loosen precision for the bfloat16 x bfloat16 -> float32 case;
+    # bfloat16 has fewer fraction bits than float16
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
else:
# added atol, to loose precision for float16xfloat16->float32 case
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
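
As a rough illustration of why the bfloat16 branch uses the looser 1e-2 atol rather than 1e-3: bfloat16 stores only 7 fraction bits versus float16's 10, so its machine epsilon is about 8x larger. A quick check (illustrative, not part of the PR):

import torch

# bfloat16: 8-bit exponent, 7 stored fraction bits -> eps = 2**-7
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125
# float16: 5-bit exponent, 10 stored fraction bits -> eps = 2**-10
print(torch.finfo(torch.float16).eps)   # 0.0009765625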