mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[Dot] [MFMA] [FMA] Update Dot implementation to support upstream tests (#260)
* [Dot] [MFMA] Support FP16 output of MFMA dot This PR adds cast of output tensor to requested data type. * add tests * fix test for FMA implementation * loose fp16xfp16->fp16 tolerance * enable FMA fallback for unsupported sizes of dot operation * rework granularity check * add constant modifier to granularity
This commit is contained in:
@@ -149,13 +149,34 @@ bool supportMMA(triton::DotOp op, int version) {
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static bool supportMFMAGranularity(int m, int n, int k) {
|
||||
// these limitations are dtype dependent, in future we may relax them
|
||||
const int granularityMN = 32;
|
||||
const int granularityK = 8;
|
||||
if (m % granularityMN != 0 || n % granularityMN != 0)
|
||||
return false;
|
||||
if (k % granularityK != 0)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool supportMFMA(triton::DotOp op) {
|
||||
auto aElemTy = op.getA().getType().cast<RankedTensorType>().getElementType();
|
||||
auto bElemTy = op.getB().getType().cast<RankedTensorType>().getElementType();
|
||||
auto aTy = op.getA().getType().cast<RankedTensorType>();
|
||||
auto bTy = op.getB().getType().cast<RankedTensorType>();
|
||||
|
||||
auto aElemTy = aTy.getElementType();
|
||||
auto bElemTy = bTy.getElementType();
|
||||
|
||||
if (aElemTy != bElemTy)
|
||||
return false;
|
||||
|
||||
auto aShape = aTy.getShape();
|
||||
auto bShape = bTy.getShape();
|
||||
|
||||
assert(aShape[1] == bShape[0]);
|
||||
if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1]))
|
||||
return false;
|
||||
|
||||
return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() ||
|
||||
aElemTy.isInteger(8);
|
||||
}
|
||||
|
||||
@@ -1227,16 +1227,19 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
|
||||
|
||||
# MFMA Test Dot tests
|
||||
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
|
||||
[(*shape, 2, False, False, epilogue, allow_tf32, dtype)
|
||||
for shape in [(64, 64, 64), (32, 32, 32)]
|
||||
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype",
|
||||
[(*shape, 2, False, False, epilogue, allow_tf32, in_dtype, out_dtype)
|
||||
for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax']
|
||||
for allow_tf32 in [True, False]
|
||||
for dtype in ['float16', 'float32']
|
||||
if not (allow_tf32 and (dtype in ['float16']))] +
|
||||
for in_dtype, out_dtype in [('float16', 'float16'),
|
||||
('float16', 'float32'),
|
||||
('float32', 'float32')]
|
||||
if not (allow_tf32 and (in_dtype in ['float16']))] +
|
||||
|
||||
[(*shape_nw, col_a, col_b, 'none', allow_tf32, dtype)
|
||||
[(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype)
|
||||
for shape_nw in [[128, 128, 32, 2],
|
||||
[128, 16, 32, 4],
|
||||
[128, 128, 64, 2],
|
||||
[128, 32, 32, 2],
|
||||
[128, 32, 64, 2],
|
||||
@@ -1261,21 +1264,27 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
for allow_tf32 in [False, True]
|
||||
for col_a in [True, False]
|
||||
for col_b in [True, False]
|
||||
for dtype in ['int8', 'float16', 'float32']])
|
||||
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, device='cuda'):
|
||||
for in_dtype in ['int8', 'float16', 'float32']
|
||||
for out_dtype in [None]])
|
||||
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, device='cuda'):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
|
||||
if torch.version.hip is not None:
|
||||
# set capability to large number to jump over check below
|
||||
# check are not relevant to amd gpu, left them for smaller diff between test_core.py and test_core_amd.py tests
|
||||
capability = (100, 100)
|
||||
if out_dtype is None:
|
||||
if in_dtype in float_dtypes:
|
||||
out_dtype = "float32"
|
||||
else:
|
||||
out_dtype = "int32"
|
||||
|
||||
if capability[0] < 7:
|
||||
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||
if capability[0] < 8:
|
||||
if dtype == 'int8':
|
||||
if in_dtype == 'int8':
|
||||
pytest.skip("Only test int8 on devices with sm >= 80")
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
elif in_dtype == 'float32' and allow_tf32:
|
||||
pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||
if capability[0] == 7:
|
||||
if (M, N, K, num_warps) == (128, 256, 32, 8):
|
||||
@@ -1289,6 +1298,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
Y, stride_yk, stride_yn,
|
||||
W, stride_wn, stride_wl,
|
||||
Z, stride_zm, stride_zn,
|
||||
out_dtype: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
|
||||
ALLOW_TF32: tl.constexpr,
|
||||
@@ -1304,7 +1314,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
|
||||
x = tl.load(Xs)
|
||||
y = tl.load(Ys)
|
||||
z = tl.dot(x, y, allow_tf32=ALLOW_TF32)
|
||||
z = tl.dot(x, y, allow_tf32=ALLOW_TF32, out_dtype=out_dtype)
|
||||
if ADD_MATRIX:
|
||||
z += tl.load(Zs)
|
||||
if ADD_ROWS:
|
||||
@@ -1316,28 +1326,28 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
if DO_SOFTMAX:
|
||||
max = tl.max(z, 1)
|
||||
z = z - max[:, None]
|
||||
num = tl.exp(z)
|
||||
num = tl.exp(z.to(tl.float32)).to(max.dtype)
|
||||
den = tl.sum(num, 1)
|
||||
z = num / den[:, None]
|
||||
if CHAIN_DOT:
|
||||
w = tl.load(Ws)
|
||||
z = tl.dot(z.to(w.dtype), w)
|
||||
z = tl.dot(z.to(w.dtype), w, out_dtype=out_dtype)
|
||||
tl.store(Zs, z)
|
||||
# input
|
||||
rs = RandomState(17)
|
||||
if col_a:
|
||||
x = numpy_random((K, M), dtype_str=dtype, rs=rs).T
|
||||
x = numpy_random((K, M), dtype_str=in_dtype, rs=rs).T
|
||||
else:
|
||||
x = numpy_random((M, K), dtype_str=dtype, rs=rs)
|
||||
x = numpy_random((M, K), dtype_str=in_dtype, rs=rs)
|
||||
if col_b:
|
||||
y = numpy_random((N, K), dtype_str=dtype, rs=rs).T
|
||||
y = numpy_random((N, K), dtype_str=in_dtype, rs=rs).T
|
||||
else:
|
||||
y = numpy_random((K, N), dtype_str=dtype, rs=rs)
|
||||
w = numpy_random((N, N), dtype_str=dtype, rs=rs)
|
||||
if 'int' not in dtype:
|
||||
y = numpy_random((K, N), dtype_str=in_dtype, rs=rs)
|
||||
w = numpy_random((N, N), dtype_str=in_dtype, rs=rs)
|
||||
if 'int' not in in_dtype:
|
||||
x *= .1
|
||||
y *= .1
|
||||
if dtype == 'float32' and allow_tf32:
|
||||
if in_dtype == 'float32' and allow_tf32:
|
||||
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
@@ -1345,18 +1355,30 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
y_tri = to_triton(y, device=device)
|
||||
w_tri = to_triton(w, device=device)
|
||||
# triton result
|
||||
if dtype == 'int8':
|
||||
if out_dtype == 'int8':
|
||||
z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
|
||||
else:
|
||||
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
|
||||
z = 1 + numpy_random((M, N), dtype_str=in_dtype, rs=rs) * .1
|
||||
|
||||
z_tri = to_triton(z, device=device)
|
||||
if epilogue == 'trans':
|
||||
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
|
||||
|
||||
if out_dtype == 'int8':
|
||||
out_dtype = tl.int8
|
||||
elif out_dtype == 'float16' and epilogue != 'softmax':
|
||||
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
|
||||
# fail with the following error: 'llvm.fmul' op requires the same type
|
||||
# for all operands and results
|
||||
out_dtype = tl.float16
|
||||
else:
|
||||
out_dtype = tl.float32
|
||||
|
||||
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
|
||||
y_tri, y_tri.stride(0), y_tri.stride(1),
|
||||
w_tri, w_tri.stride(0), w_tri.stride(1),
|
||||
z_tri, z_tri.stride(0), z_tri.stride(1),
|
||||
out_dtype,
|
||||
COL_A=col_a, COL_B=col_b,
|
||||
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
|
||||
ADD_MATRIX=epilogue == 'add-matrix',
|
||||
@@ -1367,7 +1389,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
ALLOW_TF32=allow_tf32,
|
||||
num_warps=num_warps)
|
||||
# torch result
|
||||
if dtype == 'int8':
|
||||
if in_dtype == 'int8':
|
||||
z_ref = np.matmul(x.astype(np.float32),
|
||||
y.astype(np.float32())).astype(np.int32)
|
||||
else:
|
||||
@@ -1387,24 +1409,30 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
z_ref = np.matmul(z_ref, w)
|
||||
# compare
|
||||
# print(z_ref[:,0], z_tri[:,0])
|
||||
if dtype == 'float32' or dtype == 'float16':
|
||||
if in_dtype == 'float32':
|
||||
# XXX: Somehow there's a larger difference when we use float32
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
elif out_dtype == tl.float16:
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
|
||||
else:
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
if torch.version.hip is None:
|
||||
# make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
if K > 16 or N > 16 or M > 16:
|
||||
# XXX: skip small sizes because they are not vectorized
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
if dtype == 'float32' and allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
|
||||
elif dtype == 'int8':
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
# added atol, to loose precision for float16xfloat16->float32 case
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
if torch.version.hip is not None:
|
||||
return
|
||||
# make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4):
|
||||
# XXX: skip small sizes because they are not vectorized
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
if in_dtype == 'float32' and allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||
elif in_dtype == 'float32' and allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
|
||||
elif in_dtype == 'int8':
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
elif out_dtype == tl.float16:
|
||||
assert 'mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16' in ptx
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
|
||||
@@ -2348,7 +2376,7 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
|
||||
|
||||
@pytest.mark.parametrize("shape", [(64, 64)])
|
||||
@pytest.mark.parametrize("dtype", ['float16'])
|
||||
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], isTransposed=False), MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], isTransposed = True)])
|
||||
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], isTransposed=False), MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], isTransposed=True)])
|
||||
@pytest.mark.parametrize("dst_layout", [BlockedLayout([1, 4], [4, 16], [1, 1], [1, 0])])
|
||||
def test_make_range(dtype, shape, src_layout, dst_layout, device='cuda'):
|
||||
ir = f"""
|
||||
|
||||
@@ -1238,9 +1238,20 @@ def gpu_has_mfma() -> bool:
|
||||
gfx_arch_details = gfx_arch_details.group(0).strip().split('--')
|
||||
return gfx_arch_details[1].split(':')[0] in ['gfx908', 'gfx90a', 'gfx940', 'gfx941']
|
||||
|
||||
def mfma_supported_granularity(m, n, k) -> bool:
|
||||
granularity_mn = 32
|
||||
granularity_k = 8
|
||||
if m % granularity_mn != 0 or n % granularity_mn != 0:
|
||||
return False
|
||||
if k % granularity_k != 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
|
||||
if not gpu_has_mfma():
|
||||
return False
|
||||
if not mfma_supported_granularity(M, N ,K):
|
||||
return False
|
||||
# TODO: Add check for configurations and types.
|
||||
return True
|
||||
|
||||
@@ -1278,11 +1289,25 @@ def dot(lhs: tl.tensor,
|
||||
ret_cast_scalar_ty = tl.float32 if lhs.type.scalar.is_int() else ret_scalar_ty
|
||||
lhs = cast(lhs, ret_cast_scalar_ty, builder)
|
||||
rhs = cast(rhs, ret_cast_scalar_ty, builder)
|
||||
_0 = builder.create_splat(builder.get_fp32(0), [M, N])
|
||||
if ret_cast_scalar_ty == tl.float16:
|
||||
_0 = builder.create_splat(builder.get_fp16(0), [M, N])
|
||||
else:
|
||||
_0 = builder.create_splat(builder.get_fp32(0), [M, N])
|
||||
ret_ty = tl.block_type(ret_cast_scalar_ty, [M, N])
|
||||
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
|
||||
ret_ty)
|
||||
return cast(ret, ret_scalar_ty, builder)
|
||||
if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32:
|
||||
if lhs.type.scalar.is_int():
|
||||
ret_dot_scalar_ty = tl.int32
|
||||
_0 = builder.create_splat(builder.get_int32(0), [M, N])
|
||||
else:
|
||||
ret_dot_scalar_ty = tl.float32
|
||||
_0 = builder.create_splat(builder.get_fp32(0), [M, N])
|
||||
ret_ty = tl.block_type(ret_dot_scalar_ty, [M, N])
|
||||
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
|
||||
ret_ty)
|
||||
return cast(ret, ret_scalar_ty, builder)
|
||||
|
||||
_0 = builder.create_splat(_0, [M, N])
|
||||
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
|
||||
|
||||
Reference in New Issue
Block a user