mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND][BACKEND] Add flag to control accumulation for fp8 (#2300)
Change the dot to allow taking an initial accumulator and add a flag that will allow the compiler to accumulate in a lower precision than the output type. On Hopper this flag is on by default which allows accumualting with lower precision. This only affect Hopper fp8 dot.
This commit is contained in:
@@ -131,8 +131,8 @@ def check_type_supported(dtype, device):
|
||||
cc = torch.cuda.get_device_capability()
|
||||
if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
|
||||
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
|
||||
if cc[0] < 9 and (dtype is tl.float8e4nv or dtype == "float8e4"):
|
||||
pytest.skip("float8e4 is only supported on NVGPU with cc >= 90")
|
||||
if cc[0] < 9 and (dtype is tl.float8e4nv or dtype == "float8e4nv"):
|
||||
pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90")
|
||||
|
||||
|
||||
class MmaLayout:
|
||||
@@ -3750,3 +3750,86 @@ def test_ptx_cast(dtype_str, device):
|
||||
buf14 = -torch.ones((s0, 6, 197, 197), device=device, dtype=torch_dtype)
|
||||
kernel[(4728,)](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2)
|
||||
assert buf14.to(torch.float32).mean() == -2.0
|
||||
|
||||
# -----------------------
|
||||
# test fp8 -> fp32 dot
|
||||
# -----------------------
|
||||
|
||||
|
||||
def f8_to_f16(x, dtype):
|
||||
@triton.jit
|
||||
def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offs < N
|
||||
x = tl.load(X + offs, mask=mask)
|
||||
tl.store(Y + offs, x, mask=mask)
|
||||
|
||||
ret = torch.empty(x.shape, dtype=torch.float16, device=x.device)
|
||||
grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)
|
||||
dtype = getattr(tl, dtype)
|
||||
kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024)
|
||||
return ret
|
||||
|
||||
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
low_precision_acc: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
pid_m = pid % num_pid_m
|
||||
pid_n = pid // num_pid_m
|
||||
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(a_ptrs)
|
||||
b = tl.load(b_ptrs)
|
||||
accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc)
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
tl.store(c_ptrs, accumulator)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv'])
|
||||
@pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128])
|
||||
def test_fp8_dot_acc(in_type_str, low_precision_acc, device):
|
||||
check_type_supported(in_type_str, device)
|
||||
M, N, K = 128, 256, 256
|
||||
BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 128
|
||||
A = numpy_random((M, K), dtype_str=in_type_str)
|
||||
B = numpy_random((K, N), dtype_str=in_type_str)
|
||||
Bt = B.T
|
||||
C = torch.empty((M, N), dtype=torch.float32, device='cuda')
|
||||
num_warps = 8
|
||||
a = to_triton(A, device='cuda', dst_type=in_type_str)
|
||||
b = to_triton(B, device='cuda', dst_type=in_type_str)
|
||||
grid = (triton.cdiv(M, BLOCK_M), 1)
|
||||
matmul_kernel[grid](a, b, C, M, N, K,
|
||||
a.stride(0), a.stride(1), b.stride(0), b.stride(
|
||||
1), C.stride(0), C.stride(1),
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, low_precision_acc, num_warps=num_warps)
|
||||
torch_a = torch.from_numpy(A)
|
||||
th_a = f8_to_f16(torch_a.cuda(), in_type_str)
|
||||
torch_b = torch.from_numpy(B)
|
||||
th_b = f8_to_f16(torch_b.cuda(), in_type_str)
|
||||
ref_out = torch.matmul(th_a, th_b).to(torch.float32)
|
||||
if in_type_str == 'float8e4nv':
|
||||
torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01)
|
||||
elif low_precision_acc > 32:
|
||||
torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
|
||||
else:
|
||||
torch.testing.assert_close(ref_out, C)
|
||||
|
||||
@@ -26,61 +26,61 @@ def f8_to_f16(x, dtype):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32",
|
||||
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM",
|
||||
itertools.chain(
|
||||
*[
|
||||
[
|
||||
# 1 warp
|
||||
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
# 2 warp
|
||||
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
# 4 warp
|
||||
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
# 8 warp
|
||||
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
|
||||
# variable input
|
||||
(128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True),
|
||||
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
|
||||
],
|
||||
# n-stage
|
||||
*[
|
||||
[
|
||||
(16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True),
|
||||
(64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True),
|
||||
(256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True),
|
||||
(16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
|
||||
(128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
|
||||
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4]
|
||||
],
|
||||
# mixed-precision
|
||||
*[
|
||||
[
|
||||
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True),
|
||||
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True),
|
||||
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True),
|
||||
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
|
||||
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
|
||||
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
|
||||
] for ADTYPE, BDTYPE in [("float8e4nv", "float8e5"),
|
||||
("float8e4nv", "float8e4nv"),
|
||||
("float8e5", "float8e4nv"),
|
||||
@@ -91,14 +91,14 @@ def f8_to_f16(x, dtype):
|
||||
("float16", "float32"),
|
||||
("float32", "float16"),
|
||||
("bfloat16", "float32"),
|
||||
("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
|
||||
("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]
|
||||
],
|
||||
# mixed-precision block layout
|
||||
*[
|
||||
[
|
||||
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False),
|
||||
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False),
|
||||
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False),
|
||||
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
|
||||
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
|
||||
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True),
|
||||
] for ADTYPE, BDTYPE in [("float8e4nv", "float16"),
|
||||
("float16", "float8e5"),
|
||||
("float16", "float32"),
|
||||
@@ -108,7 +108,7 @@ def f8_to_f16(x, dtype):
|
||||
],
|
||||
),
|
||||
)
|
||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32):
|
||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 7:
|
||||
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||
@@ -176,7 +176,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
|
||||
a = triton.reinterpret(a, getattr(tl, ADTYPE))
|
||||
if b_fp8:
|
||||
b = triton.reinterpret(b, getattr(tl, BDTYPE))
|
||||
tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32)
|
||||
tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32, F8_FASTACCUM)
|
||||
torch.testing.assert_close(th_c, tt_c)
|
||||
except triton.OutOfResources as e:
|
||||
pytest.skip(str(e))
|
||||
|
||||
Reference in New Issue
Block a user