[Dot] [MFMA] [FMA] Update Dot implementation to support upstream tests (#260)

* [Dot] [MFMA] Support FP16 output of MFMA dot

This PR adds cast of output tensor to requested data type.

* add tests

* fix test for FMA implementation

* loosen fp16xfp16->fp16 tolerance

* enable FMA fallback for unsupported sizes of dot operation

* rework granularity check

* add constant modifier to granularity
This commit is contained in:
Alexander Efimov
2023-08-03 20:47:18 +02:00
committed by GitHub
parent f1063bb33c
commit 86f8b64ae0
3 changed files with 116 additions and 42 deletions

View File

@@ -149,13 +149,34 @@ bool supportMMA(triton::DotOp op, int version) {
}
#ifdef USE_ROCM
// Returns true when an (m x k) * (k x n) dot problem tiles evenly onto the
// MFMA granularity supported by the backend.
// NOTE: these limits are dtype dependent; in the future they may be relaxed.
static bool supportMFMAGranularity(int m, int n, int k) {
  constexpr int kGranularityMN = 32;
  constexpr int kGranularityK = 8;
  return m % kGranularityMN == 0 && n % kGranularityMN == 0 &&
         k % kGranularityK == 0;
}
// Returns true when the given tt.dot op can be lowered to AMD MFMA
// (matrix-core) instructions: both operands must share an element type,
// that type must be f16/bf16/f32/i8, and the problem shape must tile onto
// MFMA granularity (see supportMFMAGranularity).
bool supportMFMA(triton::DotOp op) {
  auto aTy = op.getA().getType().cast<RankedTensorType>();
  auto bTy = op.getB().getType().cast<RankedTensorType>();

  auto aElemTy = aTy.getElementType();
  auto bElemTy = bTy.getElementType();
  // Mixed-precision operands are not supported by the MFMA path.
  if (aElemTy != bElemTy)
    return false;

  auto aShape = aTy.getShape();
  auto bShape = bTy.getShape();
  // Inner dimensions of a (M x K) and b (K x N) must agree.
  assert(aShape[1] == bShape[0]);
  // Fall back (e.g. to FMA) when the shape does not fit MFMA granularity.
  if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1]))
    return false;

  return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() ||
         aElemTy.isInteger(8);
}

View File

@@ -1227,16 +1227,19 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# MFMA Test Dot tests
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
[(*shape, 2, False, False, epilogue, allow_tf32, dtype)
for shape in [(64, 64, 64), (32, 32, 32)]
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype",
[(*shape, 2, False, False, epilogue, allow_tf32, in_dtype, out_dtype)
for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax']
for allow_tf32 in [True, False]
for dtype in ['float16', 'float32']
if not (allow_tf32 and (dtype in ['float16']))] +
for in_dtype, out_dtype in [('float16', 'float16'),
('float16', 'float32'),
('float32', 'float32')]
if not (allow_tf32 and (in_dtype in ['float16']))] +
[(*shape_nw, col_a, col_b, 'none', allow_tf32, dtype)
[(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype)
for shape_nw in [[128, 128, 32, 2],
[128, 16, 32, 4],
[128, 128, 64, 2],
[128, 32, 32, 2],
[128, 32, 64, 2],
@@ -1261,21 +1264,27 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
for allow_tf32 in [False, True]
for col_a in [True, False]
for col_b in [True, False]
for dtype in ['int8', 'float16', 'float32']])
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, device='cuda'):
for in_dtype in ['int8', 'float16', 'float32']
for out_dtype in [None]])
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, device='cuda'):
capability = torch.cuda.get_device_capability()
if torch.version.hip is not None:
# set capability to large number to jump over check below
# check are not relevant to amd gpu, left them for smaller diff between test_core.py and test_core_amd.py tests
capability = (100, 100)
if out_dtype is None:
if in_dtype in float_dtypes:
out_dtype = "float32"
else:
out_dtype = "int32"
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if capability[0] < 8:
if dtype == 'int8':
if in_dtype == 'int8':
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
elif in_dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
if capability[0] == 7:
if (M, N, K, num_warps) == (128, 256, 32, 8):
@@ -1289,6 +1298,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
Y, stride_yk, stride_yn,
W, stride_wn, stride_wl,
Z, stride_zm, stride_zn,
out_dtype: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
ALLOW_TF32: tl.constexpr,
@@ -1304,7 +1314,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
x = tl.load(Xs)
y = tl.load(Ys)
z = tl.dot(x, y, allow_tf32=ALLOW_TF32)
z = tl.dot(x, y, allow_tf32=ALLOW_TF32, out_dtype=out_dtype)
if ADD_MATRIX:
z += tl.load(Zs)
if ADD_ROWS:
@@ -1316,28 +1326,28 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
if DO_SOFTMAX:
max = tl.max(z, 1)
z = z - max[:, None]
num = tl.exp(z)
num = tl.exp(z.to(tl.float32)).to(max.dtype)
den = tl.sum(num, 1)
z = num / den[:, None]
if CHAIN_DOT:
w = tl.load(Ws)
z = tl.dot(z.to(w.dtype), w)
z = tl.dot(z.to(w.dtype), w, out_dtype=out_dtype)
tl.store(Zs, z)
# input
rs = RandomState(17)
if col_a:
x = numpy_random((K, M), dtype_str=dtype, rs=rs).T
x = numpy_random((K, M), dtype_str=in_dtype, rs=rs).T
else:
x = numpy_random((M, K), dtype_str=dtype, rs=rs)
x = numpy_random((M, K), dtype_str=in_dtype, rs=rs)
if col_b:
y = numpy_random((N, K), dtype_str=dtype, rs=rs).T
y = numpy_random((N, K), dtype_str=in_dtype, rs=rs).T
else:
y = numpy_random((K, N), dtype_str=dtype, rs=rs)
w = numpy_random((N, N), dtype_str=dtype, rs=rs)
if 'int' not in dtype:
y = numpy_random((K, N), dtype_str=in_dtype, rs=rs)
w = numpy_random((N, N), dtype_str=in_dtype, rs=rs)
if 'int' not in in_dtype:
x *= .1
y *= .1
if dtype == 'float32' and allow_tf32:
if in_dtype == 'float32' and allow_tf32:
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
@@ -1345,18 +1355,30 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
y_tri = to_triton(y, device=device)
w_tri = to_triton(w, device=device)
# triton result
if dtype == 'int8':
if out_dtype == 'int8':
z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
else:
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
z = 1 + numpy_random((M, N), dtype_str=in_dtype, rs=rs) * .1
z_tri = to_triton(z, device=device)
if epilogue == 'trans':
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
if out_dtype == 'int8':
out_dtype = tl.int8
elif out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
else:
out_dtype = tl.float32
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
y_tri, y_tri.stride(0), y_tri.stride(1),
w_tri, w_tri.stride(0), w_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
out_dtype,
COL_A=col_a, COL_B=col_b,
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX=epilogue == 'add-matrix',
@@ -1367,7 +1389,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
ALLOW_TF32=allow_tf32,
num_warps=num_warps)
# torch result
if dtype == 'int8':
if in_dtype == 'int8':
z_ref = np.matmul(x.astype(np.float32),
y.astype(np.float32())).astype(np.int32)
else:
@@ -1387,24 +1409,30 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
z_ref = np.matmul(z_ref, w)
# compare
# print(z_ref[:,0], z_tri[:,0])
if dtype == 'float32' or dtype == 'float16':
if in_dtype == 'float32':
# XXX: Somehow there's a larger difference when we use float32
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
elif out_dtype == tl.float16:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
if torch.version.hip is None:
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
if K > 16 or N > 16 or M > 16:
# XXX: skip small sizes because they are not vectorized
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
elif dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
elif dtype == 'int8':
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
# added atol, to loose precision for float16xfloat16->float32 case
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
if torch.version.hip is not None:
return
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4):
# XXX: skip small sizes because they are not vectorized
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if in_dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
elif in_dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
elif in_dtype == 'int8':
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
elif out_dtype == tl.float16:
assert 'mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16' in ptx
@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
@@ -2348,7 +2376,7 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
@pytest.mark.parametrize("shape", [(64, 64)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], isTransposed=False), MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], isTransposed = True)])
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], isTransposed=False), MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], isTransposed=True)])
@pytest.mark.parametrize("dst_layout", [BlockedLayout([1, 4], [4, 16], [1, 1], [1, 0])])
def test_make_range(dtype, shape, src_layout, dst_layout, device='cuda'):
ir = f"""

View File

@@ -1238,9 +1238,20 @@ def gpu_has_mfma() -> bool:
gfx_arch_details = gfx_arch_details.group(0).strip().split('--')
return gfx_arch_details[1].split(':')[0] in ['gfx908', 'gfx90a', 'gfx940', 'gfx941']
def mfma_supported_granularity(m, n, k) -> bool:
    """Return True when an (m x k) * (k x n) dot tiles evenly onto MFMA granularity.

    These limits are dtype dependent; they may be relaxed in the future.
    """
    gran_mn, gran_k = 32, 8
    return m % gran_mn == 0 and n % gran_mn == 0 and k % gran_k == 0
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
    """Return True when a dot of shape (M x K) * (K x N) can use MFMA.

    Requires an MFMA-capable GPU and an MFMA-compatible problem shape.
    `allow_tf32` and `ret_scalar_ty` are currently unused.
    """
    # TODO: Add check for configurations and types.
    return gpu_has_mfma() and mfma_supported_granularity(M, N, K)
@@ -1278,11 +1289,25 @@ def dot(lhs: tl.tensor,
ret_cast_scalar_ty = tl.float32 if lhs.type.scalar.is_int() else ret_scalar_ty
lhs = cast(lhs, ret_cast_scalar_ty, builder)
rhs = cast(rhs, ret_cast_scalar_ty, builder)
_0 = builder.create_splat(builder.get_fp32(0), [M, N])
if ret_cast_scalar_ty == tl.float16:
_0 = builder.create_splat(builder.get_fp16(0), [M, N])
else:
_0 = builder.create_splat(builder.get_fp32(0), [M, N])
ret_ty = tl.block_type(ret_cast_scalar_ty, [M, N])
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
ret_ty)
return cast(ret, ret_scalar_ty, builder)
if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32:
if lhs.type.scalar.is_int():
ret_dot_scalar_ty = tl.int32
_0 = builder.create_splat(builder.get_int32(0), [M, N])
else:
ret_dot_scalar_ty = tl.float32
_0 = builder.create_splat(builder.get_fp32(0), [M, N])
ret_ty = tl.block_type(ret_dot_scalar_ty, [M, N])
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
ret_ty)
return cast(ret, ret_scalar_ty, builder)
_0 = builder.create_splat(_0, [M, N])
ret_ty = tl.block_type(ret_scalar_ty, [M, N])