mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[MFMA] MI200 bfloat16 support (#294)
This PR enables bfloat16 support in MFMA dot on MI200. It uses the mfma_f32_32x32x8bf16_1k instruction.
This commit is contained in:
@@ -459,6 +459,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
|
||||
|
||||
Type smemPtrTy = getShemPtrTy(aElemTy);
|
||||
Type resElemTy = aElemTy.isBF16() ? i16_ty : aElemTy;
|
||||
|
||||
int loadsPerThread = offsets.size() / (numRepM * numRepK);
|
||||
const int elemsPerLoad = numOfElems / loadsPerThread;
|
||||
@@ -466,7 +467,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
|
||||
for (int m = 0; m < numRepM; ++m) {
|
||||
for (int k = 0; k < numRepK; ++k) {
|
||||
auto vecTy = vec_ty(aElemTy, numOfElems);
|
||||
auto vecTy = vec_ty(resElemTy, numOfElems);
|
||||
Value valVec = undef(vecTy);
|
||||
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
|
||||
auto loadVecTy = vec_ty(aElemTy, elemsPerLoad);
|
||||
@@ -479,11 +480,13 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
|
||||
Value elemVal =
|
||||
extract_element(aElemTy, vectorValue, i32_val(elemId));
|
||||
elemVal = bitcast(elemVal, resElemTy);
|
||||
valVec = insert_element(vecTy, valVec, elemVal,
|
||||
i32_val(loadId * elemsPerLoad + elemId));
|
||||
}
|
||||
} else {
|
||||
valVec = extract_element(aElemTy, vectorValue, i32_val(0));
|
||||
valVec = bitcast(valVec, resElemTy);
|
||||
}
|
||||
}
|
||||
if (aElemTy == i8_ty)
|
||||
@@ -497,6 +500,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
numReps, smemObj, sharedLayout);
|
||||
|
||||
Value smemBase = computeBasePtr(rewriter, loc, smemObj);
|
||||
Type resElemTy = aElemTy.isBF16() ? i16_ty : aElemTy;
|
||||
|
||||
Type smemPtrTy = getShemPtrTy(aElemTy);
|
||||
|
||||
@@ -505,7 +509,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
|
||||
for (int m = 0; m < numRepM; ++m) {
|
||||
for (int k = 0; k < numRepK; ++k) {
|
||||
auto vecTy = vec_ty(aElemTy, numOfElems);
|
||||
auto vecTy = vec_ty(resElemTy, numOfElems);
|
||||
Value valVec = undef(vecTy);
|
||||
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
|
||||
auto loadVecTy = vec_ty(aElemTy, elemsPerLoad);
|
||||
@@ -518,11 +522,13 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
|
||||
Value elemVal =
|
||||
extract_element(aElemTy, vectorValue, i32_val(elemId));
|
||||
elemVal = bitcast(elemVal, resElemTy);
|
||||
valVec = insert_element(vecTy, valVec, elemVal,
|
||||
i32_val(loadId * elemsPerLoad + elemId));
|
||||
}
|
||||
} else {
|
||||
valVec = extract_element(aElemTy, vectorValue, i32_val(0));
|
||||
valVec = bitcast(valVec, resElemTy);
|
||||
}
|
||||
}
|
||||
if (aElemTy == i8_ty)
|
||||
@@ -602,6 +608,8 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
|
||||
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
|
||||
|
||||
Type resElemTy = bElemTy.isBF16() ? i16_ty : bElemTy;
|
||||
|
||||
Type smemPtrTy = getShemPtrTy(bElemTy);
|
||||
|
||||
const int loadsPerThread = offsets.size() / (numRepN * numRepK);
|
||||
@@ -610,7 +618,7 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
|
||||
for (int n = 0; n < numRepN; ++n) {
|
||||
for (int k = 0; k < numRepK; ++k) {
|
||||
auto vecTy = vec_ty(bElemTy, numOfElems);
|
||||
auto vecTy = vec_ty(resElemTy, numOfElems);
|
||||
Value valVec = undef(vecTy);
|
||||
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
|
||||
auto loadVecTy = vec_ty(bElemTy, elemsPerLoad);
|
||||
@@ -623,11 +631,13 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
|
||||
Value elemVal =
|
||||
extract_element(bElemTy, vectorValue, i32_val(elemId));
|
||||
elemVal = bitcast(elemVal, resElemTy);
|
||||
valVec = insert_element(vecTy, valVec, elemVal,
|
||||
i32_val(loadId * elemsPerLoad + elemId));
|
||||
}
|
||||
} else {
|
||||
valVec = extract_element(bElemTy, vectorValue, i32_val(0));
|
||||
valVec = bitcast(valVec, resElemTy);
|
||||
}
|
||||
}
|
||||
if (bElemTy == i8_ty)
|
||||
@@ -642,13 +652,15 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
|
||||
Value smemBase = computeBasePtr(rewriter, loc, smemObj);
|
||||
|
||||
Type resElemTy = bElemTy.isBF16() ? i16_ty : bElemTy;
|
||||
|
||||
Type smemPtrTy = getShemPtrTy(bElemTy);
|
||||
|
||||
int loadsPerThread = offsets.size() / (numReps[0] * numReps[1]);
|
||||
int elemsPerLoad = numOfElems / loadsPerThread;
|
||||
for (int n = 0; n < numRepN; ++n) {
|
||||
for (int k = 0; k < numRepK; ++k) {
|
||||
auto vecTy = vec_ty(bElemTy, numOfElems);
|
||||
auto vecTy = vec_ty(resElemTy, numOfElems);
|
||||
Value valVec = undef(vecTy);
|
||||
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
|
||||
auto loadVecTy = vec_ty(bElemTy, elemsPerLoad);
|
||||
@@ -661,11 +673,13 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
|
||||
Value elemVal =
|
||||
extract_element(bElemTy, vectorValue, i32_val(elemId));
|
||||
elemVal = bitcast(elemVal, resElemTy);
|
||||
valVec = insert_element(vecTy, valVec, elemVal,
|
||||
i32_val(loadId * elemsPerLoad + elemId));
|
||||
}
|
||||
} else {
|
||||
valVec = extract_element(bElemTy, vectorValue, i32_val(0));
|
||||
valVec = bitcast(valVec, resElemTy);
|
||||
}
|
||||
}
|
||||
if (bElemTy == i8_ty)
|
||||
|
||||
@@ -57,7 +57,7 @@ struct DotOpMFMAConversionHelper {
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF16_BF16_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x4bf16>(
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x8bf16_1k>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_FP32_FP32_FP32:
|
||||
|
||||
@@ -669,7 +669,7 @@ static SmallVector<int64_t> getMFMAInstrShape(Type abElemType) {
|
||||
if (abElemType.isF32())
|
||||
return {32l, 32l, 2l}; // FP32_FP32_FP32_FP32;
|
||||
if (abElemType.isBF16())
|
||||
return {32l, 32l, 4l}; // FP32_BF16_BF16_FP32;
|
||||
return {32l, 32l, 8l}; // FP32_BF16_BF16_FP32;
|
||||
if (abElemType.isInteger(8))
|
||||
return {32l, 32l, 8l}; // INT32_INT8_INT8_INT32;
|
||||
if (abElemType.isF64())
|
||||
|
||||
@@ -1233,6 +1233,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax']
|
||||
for allow_tf32 in [True, False]
|
||||
for in_dtype, out_dtype in [('float16', 'float16'),
|
||||
('bfloat16', 'float32'),
|
||||
('float16', 'float32'),
|
||||
('float32', 'float32')]
|
||||
if not (allow_tf32 and (in_dtype in ['float16']))] +
|
||||
@@ -1264,7 +1265,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
for allow_tf32 in [False, True]
|
||||
for col_a in [True, False]
|
||||
for col_b in [True, False]
|
||||
for in_dtype in ['int8', 'float16', 'float32']
|
||||
for in_dtype in ['int8', 'bfloat16', 'float16', 'float32']
|
||||
for out_dtype in [None]])
|
||||
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, device='cuda'):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
@@ -1354,6 +1355,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
x_tri = to_triton(x, device=device)
|
||||
y_tri = to_triton(y, device=device)
|
||||
w_tri = to_triton(w, device=device)
|
||||
if in_dtype == 'bfloat16':
|
||||
x_tri = x_tri.to(torch.bfloat16)
|
||||
y_tri = y_tri.to(torch.bfloat16)
|
||||
w_tri = w_tri.to(torch.bfloat16)
|
||||
# triton result
|
||||
if out_dtype == 'int8':
|
||||
z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
|
||||
@@ -1414,6 +1419,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
elif out_dtype == tl.float16:
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
|
||||
elif in_dtype == 'bfloat16':
|
||||
# added atol to loosen precision for the bfloat16xbfloat16->float32 case
|
||||
# bfloat16 has fewer fraction bits than float16
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
|
||||
else:
|
||||
# added atol to loosen precision for the float16xfloat16->float32 case
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
|
||||
Reference in New Issue
Block a user