mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TESTING] clean up testing.do_bench (#1513)
This commit is contained in:
@@ -67,7 +67,7 @@ matmul_data = {
|
||||
(64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
|
||||
(64, 4096, 4096): {'float16': 0.16, 'float32': 0.162, 'int8': 0.097},
|
||||
(64, 8192, 8192): {'float16': 0.30, 'float32': 0.257, 'int8': 0.174},
|
||||
(1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
|
||||
(1024, 64, 1024): {'float16': 0.037, 'float32': 0.0458, 'int8': 0.017},
|
||||
(4096, 64, 4096): {'float16': 0.16, 'float32': 0.177, 'int8': 0.102},
|
||||
(8192, 64, 8192): {'float16': 0.25, 'float32': 0.230, 'int8': 0.177},
|
||||
}
|
||||
@@ -94,10 +94,10 @@ def test_matmul(M, N, K, dtype_str):
|
||||
a = torch.randn((M, K), dtype=dtype, device='cuda')
|
||||
b = torch.randn((K, N), dtype=dtype, device='cuda')
|
||||
fn = lambda: triton.ops.matmul(a, b)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
|
||||
ms = triton.testing.do_bench(fn, warmup=100, rep=300)
|
||||
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
|
||||
|
||||
#######################
|
||||
@@ -131,8 +131,8 @@ elementwise_data = {
|
||||
'a100': {
|
||||
1024 * 16: 0.008,
|
||||
1024 * 64: 0.034,
|
||||
1024 * 256: 0.114,
|
||||
1024 * 1024: 0.315,
|
||||
1024 * 256: 0.132,
|
||||
1024 * 1024: 0.352,
|
||||
1024 * 4096: 0.580,
|
||||
1024 * 16384: 0.782,
|
||||
1024 * 65536: 0.850,
|
||||
@@ -150,10 +150,10 @@ def test_elementwise(N):
|
||||
y = torch.randn_like(z)
|
||||
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
|
||||
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
|
||||
ms = triton.testing.do_bench(fn, warmup=100, rep=500)
|
||||
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
|
||||
#######################
|
||||
# Flash-Attention
|
||||
@@ -189,7 +189,7 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
fn = lambda: o.backward(do, retain_graph=True)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
|
||||
ms = triton.testing.do_bench(fn, warmup=100, rep=500)
|
||||
# compute flops
|
||||
flops_per_matmul = 2. * Z * H * N_CTX * N_CTX * D_HEAD * 0.5
|
||||
total_flops = 2 * flops_per_matmul
|
||||
@@ -201,4 +201,4 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
|
||||
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
|
||||
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
|
||||
Reference in New Issue
Block a user