mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND][BACKEND] Add support for FP16 output for tl.dot (#1258)
--------- Co-authored-by: Fei Hu <fhu@microsoft.com>
This commit is contained in:
@@ -1173,15 +1173,17 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
# ---------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
|
||||
[(*shape, 4, False, False, epilogue, allow_tf32, dtype)
|
||||
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype",
|
||||
[(*shape, 4, False, False, epilogue, allow_tf32, in_dtype, out_dtype)
|
||||
for shape in [(64, 64, 64), (16, 16, 16)]
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
||||
for allow_tf32 in [True, False]
|
||||
for dtype in ['float16', 'float32']
|
||||
if not (allow_tf32 and (dtype in ['float16']))] +
|
||||
for in_dtype, out_dtype in [('float16', 'float16'),
|
||||
('float16', 'float32'),
|
||||
('float32', 'float32')]
|
||||
if not (allow_tf32 and (in_dtype in ['float16']))] +
|
||||
|
||||
[(*shape_nw, col_a, col_b, 'none', allow_tf32, dtype)
|
||||
[(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype)
|
||||
for shape_nw in [[128, 256, 32, 8],
|
||||
[128, 16, 32, 4],
|
||||
[32, 128, 64, 4],
|
||||
@@ -1194,19 +1196,25 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
for allow_tf32 in [True]
|
||||
for col_a in [True, False]
|
||||
for col_b in [True, False]
|
||||
for dtype in ['int8', 'float16', 'float32']])
|
||||
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, device='cuda'):
|
||||
for in_dtype, out_dtype in [('int8', 'int8'),
|
||||
('float16', 'float16'),
|
||||
('float16', 'float32'),
|
||||
('float32', 'float32')]])
|
||||
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, device='cuda'):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 7:
|
||||
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||
if capability[0] < 8:
|
||||
if dtype == 'int8':
|
||||
if in_dtype == 'int8':
|
||||
pytest.skip("Only test int8 on devices with sm >= 80")
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
elif in_dtype == 'float32' and allow_tf32:
|
||||
pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||
if capability[0] == 7:
|
||||
if (M, N, K, num_warps) == (128, 256, 32, 8):
|
||||
pytest.skip("shared memory out of resource")
|
||||
if out_dtype == 'float16':
|
||||
# TODO: support out_dtype=float16 for tl.dot on V100
|
||||
pytest.skip("Only test out_dtype=float16 on devices with sm >=80")
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
|
||||
|
||||
@@ -1216,6 +1224,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
Y, stride_yk, stride_yn,
|
||||
W, stride_wn, stride_wl,
|
||||
Z, stride_zm, stride_zn,
|
||||
out_dtype: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
|
||||
ALLOW_TF32: tl.constexpr,
|
||||
@@ -1231,7 +1240,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
|
||||
x = tl.load(Xs)
|
||||
y = tl.load(Ys)
|
||||
z = tl.dot(x, y, allow_tf32=ALLOW_TF32)
|
||||
z = tl.dot(x, y, allow_tf32=ALLOW_TF32, out_dtype=out_dtype)
|
||||
if ADD_MATRIX:
|
||||
z += tl.load(Zs)
|
||||
if ADD_ROWS:
|
||||
@@ -1248,23 +1257,23 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
z = num / den[:, None]
|
||||
if CHAIN_DOT:
|
||||
w = tl.load(Ws)
|
||||
z = tl.dot(z.to(w.dtype), w)
|
||||
z = tl.dot(z.to(w.dtype), w, out_dtype=out_dtype)
|
||||
tl.store(Zs, z)
|
||||
# input
|
||||
rs = RandomState(17)
|
||||
if col_a:
|
||||
x = numpy_random((K, M), dtype_str=dtype, rs=rs).T
|
||||
x = numpy_random((K, M), dtype_str=in_dtype, rs=rs).T
|
||||
else:
|
||||
x = numpy_random((M, K), dtype_str=dtype, rs=rs)
|
||||
x = numpy_random((M, K), dtype_str=in_dtype, rs=rs)
|
||||
if col_b:
|
||||
y = numpy_random((N, K), dtype_str=dtype, rs=rs).T
|
||||
y = numpy_random((N, K), dtype_str=in_dtype, rs=rs).T
|
||||
else:
|
||||
y = numpy_random((K, N), dtype_str=dtype, rs=rs)
|
||||
w = numpy_random((N, N), dtype_str=dtype, rs=rs)
|
||||
if 'int' not in dtype:
|
||||
y = numpy_random((K, N), dtype_str=in_dtype, rs=rs)
|
||||
w = numpy_random((N, N), dtype_str=in_dtype, rs=rs)
|
||||
if 'int' not in in_dtype:
|
||||
x *= .1
|
||||
y *= .1
|
||||
if dtype == 'float32' and allow_tf32:
|
||||
if in_dtype == 'float32' and allow_tf32:
|
||||
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
@@ -1272,18 +1281,30 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
y_tri = to_triton(y, device=device)
|
||||
w_tri = to_triton(w, device=device)
|
||||
# triton result
|
||||
if dtype == 'int8':
|
||||
if out_dtype == 'int8':
|
||||
z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
|
||||
else:
|
||||
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
|
||||
z = 1 + numpy_random((M, N), dtype_str=in_dtype, rs=rs) * .1
|
||||
|
||||
z_tri = to_triton(z, device=device)
|
||||
if epilogue == 'trans':
|
||||
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
|
||||
|
||||
if out_dtype == 'int8':
|
||||
out_dtype = tl.int8
|
||||
elif out_dtype == 'float16' and epilogue != 'softmax':
|
||||
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
|
||||
# fail with the following error: 'llvm.fmul' op requires the same type
|
||||
# for all operands and results
|
||||
out_dtype = tl.float16
|
||||
else:
|
||||
out_dtype = tl.float32
|
||||
|
||||
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
|
||||
y_tri, y_tri.stride(0), y_tri.stride(1),
|
||||
w_tri, w_tri.stride(0), w_tri.stride(1),
|
||||
z_tri, z_tri.stride(0), z_tri.stride(1),
|
||||
out_dtype,
|
||||
COL_A=col_a, COL_B=col_b,
|
||||
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
|
||||
ADD_MATRIX=epilogue == 'add-matrix',
|
||||
@@ -1294,7 +1315,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
ALLOW_TF32=allow_tf32,
|
||||
num_warps=num_warps)
|
||||
# torch result
|
||||
if dtype == 'int8':
|
||||
if in_dtype == 'int8':
|
||||
z_ref = np.matmul(x.astype(np.float32),
|
||||
y.astype(np.float32())).astype(np.int32)
|
||||
else:
|
||||
@@ -1314,9 +1335,11 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
z_ref = np.matmul(z_ref, w)
|
||||
# compare
|
||||
# print(z_ref[:,0], z_tri[:,0])
|
||||
if dtype == 'float32':
|
||||
if in_dtype == 'float32':
|
||||
# XXX: Somehow there's a larger difference when we use float32
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
elif out_dtype == tl.float16:
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
else:
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
# make sure ld/st are vectorized
|
||||
@@ -1325,12 +1348,14 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
# XXX: skip small sizes because they are not vectorized
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
if dtype == 'float32' and allow_tf32:
|
||||
if in_dtype == 'float32' and allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
elif in_dtype == 'float32' and allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
|
||||
elif dtype == 'int8':
|
||||
elif in_dtype == 'int8':
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
elif out_dtype == tl.float16:
|
||||
assert 'mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16' in ptx
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
|
||||
@@ -1467,7 +1492,7 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
|
||||
|
||||
# Without a dot product the memory doesn't get promoted to shared.
|
||||
o = tl.dot(x, w)
|
||||
o = tl.dot(x, w, out_dtype=tl.float32)
|
||||
|
||||
# Store output
|
||||
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
|
||||
|
||||
Reference in New Issue
Block a user