[TESTING] now using cuda graphs for perf regression tests (#1925)

This commit is contained in:
Philippe Tillet
2023-07-10 22:49:25 -07:00
committed by GitHub
parent 4a20d5010b
commit 7e3ebbc4c8
3 changed files with 114 additions and 91 deletions

View File

@@ -38,56 +38,24 @@ sm_clocks = {'v100': 1350, 'a100': 1350}
mem_clocks = {'v100': 877, 'a100': 1215}
matmul_data = {
'v100': {
# square
(512, 512, 512): {'float16': 0.158},
(1024, 1024, 1024): {'float16': 0.466},
(2048, 2048, 2048): {'float16': 0.695},
(4096, 4096, 4096): {'float16': 0.831},
(8192, 8192, 8192): {'float16': 0.849},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0128},
(16, 4096, 4096): {'float16': 0.0883},
(16, 8192, 8192): {'float16': 0.101},
(64, 1024, 1024): {'float16': 0.073},
(64, 4096, 4096): {'float16': 0.270},
(64, 8192, 8192): {'float16': 0.459},
(1024, 64, 1024): {'float16': 0.0692},
(4096, 64, 4096): {'float16': 0.264},
(8192, 64, 8192): {'float16': 0.452},
# Non pow 2 shapes
(1000, 200, 100): {'float16': 0.084},
(1000, 200, 700): {'float16': 0.084},
(994, 136, 402): {'float16': 0.084},
(995, 135, 409): {'float16': 0.084},
(99, 1357, 409): {'float16': 0.084},
},
# NOTE:
# A100 in the CI server is slow-ish for some reason.
# On some other servers, we are getting about 90% peak for 8kx8x8k float16
'a100': {
# square
(512, 512, 512): {'float16': 0.084, 'float32': 0.12, 'int8': 0.05},
(1024, 1024, 1024): {'float16': 0.332, 'float32': 0.352, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.635, 'float32': 0.522, 'int8': 0.34},
(4096, 4096, 4096): {'float16': 0.750, 'float32': 0.810, 'int8': 0.46},
(8192, 8192, 8192): {'float16': 0.760, 'float32': 0.760, 'int8': 0.51},
(512, 512, 512): {'float16': 0.061, 'float32': 0.097, 'int8': 0.05},
(1024, 1024, 1024): {'float16': 0.283, 'float32': 0.313, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.618, 'float32': 0.532, 'int8': 0.34},
(4096, 4096, 4096): {'float16': 0.751, 'float32': 0.726, 'int8': 0.46},
(8192, 8192, 8192): {'float16': 0.786, 'float32': 0.754, 'int8': 0.51},
# tall-skinny
(16, 1024, 1024): {'float16': 0.008, 'float32': 0.009, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.036, 'float32': 0.038, 'int8': 0.026},
(16, 8192, 8192): {'float16': 0.056, 'float32': 0.061, 'int8': 0.043},
(64, 1024, 1024): {'float16': 0.020, 'float32': 0.030, 'int8': 0.017},
(64, 4096, 4096): {'float16': 0.160, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.280, 'float32': 0.257, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.040, 'float32': 0.050, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.160, 'float32': 0.200, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.250, 'float32': 0.23, 'int8': 0.177},
# Non pow 2 shapes
(1000, 200, 100): {'float16': 0.011, 'float32': 0.017, 'int8': 0.05},
(1000, 200, 700): {'float16': 0.027, 'float32': 0.047, 'int8': 0.05},
(994, 136, 402): {'float16': 0.015, 'float32': 0.024, 'int8': 0.05},
(995, 135, 409): {'float16': 0.015, 'float32': 0.025, 'int8': 0.05},
(99, 1357, 409): {'float16': 0.011, 'float32': 0.036, 'int8': 0.05}
(16, 1024, 1024): {'float16': 0.006, 'float32': 0.009, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.057, 'float32': 0.051, 'int8': 0.026},
(16, 8192, 8192): {'float16': 0.077, 'float32': 0.077, 'int8': 0.043},
(64, 1024, 1024): {'float16': 0.018, 'float32': 0.023, 'int8': 0.017},
(64, 4096, 4096): {'float16': 0.150, 'float32': 0.000, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.338, 'float32': 0.000, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.029, 'float32': 0.046, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.179, 'float32': 0.214, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.278, 'float32': 0.000, 'int8': 0.177},
}
}
@@ -97,6 +65,8 @@ matmul_data = {
for M, N, K in matmul_data[DEVICE_NAME].keys()
for dtype_str in ['float16', 'float32']])
def test_matmul(M, N, K, dtype_str):
stream = torch.cuda.Stream()
torch.cuda.set_stream(stream)
if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100':
pytest.skip('Only test float32 & int8 on a100')
if (M, N, K) in [(64, 4096, 4096), (64, 8192, 8192), (8192, 64, 8192)] and dtype_str == 'float32':
@@ -114,11 +84,11 @@ def test_matmul(M, N, K, dtype_str):
a = torch.randn((M, K), dtype=dtype, device='cuda')
b = torch.randn((K, N), dtype=dtype, device='cuda')
fn = lambda: triton.ops.matmul(a, b)
ms = triton.testing.do_bench(fn, return_mode="min", warmup=100, rep=300)
ms = triton.testing.do_bench_cudagraph(fn)
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
print_perf(ms, cur_gpu_util, ref_gpu_util)
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01)
#######################
@@ -140,31 +110,17 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,
elementwise_data = {
'v100': {
1024 * 16: {'float16': 0.0219, 'float32': 0.010},
1024 * 64: {'float16': 0.0791, 'float32': 0.010},
1024 * 256: {'float16': 0.243, 'float32': 0.010},
1024 * 1024: {'float16': 0.530, 'float32': 0.010},
1024 * 4096: {'float16': 0.796, 'float32': 0.010},
1024 * 16384: {'float16': 0.905, 'float32': 0.010},
1024 * 65536: {'float16': 0.939, 'float32': 0.010},
# Non pow 2
1020 * 100: {'float16': 0.010, 'float32': 0.010},
995 * 125: {'float16': 0.010, 'float32': 0.010},
10003 * 7007: {'float16': 0.010, 'float32': 0.010},
},
'a100': {
1024 * 16: {'float16': 0.010, 'bfloat16': 0.010, 'float32': 0.020},
1024 * 64: {'float16': 0.040, 'bfloat16': 0.040, 'float32': 0.066},
1024 * 256: {'float16': 0.132, 'bfloat16': 0.132, 'float32': 0.227},
1024 * 1024: {'float16': 0.353, 'bfloat16': 0.353, 'float32': 0.488},
1024 * 4096: {'float16': 0.605, 'bfloat16': 0.605, 'float32': 0.705},
1024 * 16384: {'float16': 0.758, 'bfloat16': 0.750, 'float32': 0.819},
1024 * 65536: {'float16': 0.850, 'bfloat16': 0.850, 'float32': 0.870},
1024 * 16: {'float16': 0.003, 'float32': 0.007},
1024 * 64: {'float16': 0.013, 'float32': 0.026},
1024 * 256: {'float16': 0.053, 'float32': 0.105},
1024 * 1024: {'float16': 0.212, 'float32': 0.420},
1024 * 4096: {'float16': 0.791, 'float32': 0.668},
1024 * 16384: {'float16': 0.762, 'float32': 0.812},
1024 * 65536: {'float16': 0.846, 'float32': 0.869},
# Non pow 2
1020 * 100: {'float16': 0.051, 'bfloat16': 0.051, 'float32': 0.103},
995 * 125: {'float16': 0.063, 'bfloat16': 0.063, 'float32': 0.126},
10003 * 7007: {'float16': 0.544, 'bfloat16': 0.541, 'float32': 0.861},
1020 * 100: {'float16': 0.020, 'float32': 0.041},
10003 * 7007: {'float16': 0.513, 'float32': 0.861},
}
}
@@ -172,22 +128,25 @@ elementwise_data = {
@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys())
@pytest.mark.parametrize("dtype_str", ['float16', 'bfloat16', 'float32'])
def test_elementwise(N, dtype_str):
stream = torch.cuda.Stream()
torch.cuda.set_stream(stream)
torch.manual_seed(0)
if dtype_str in ['bfloat16'] and DEVICE_NAME != 'a100':
pytest.skip('Only test bfloat16 on a100')
dtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}[dtype_str]
ref_gpu_util = elementwise_data[DEVICE_NAME][N][dtype_str]
ref_dtype_str = 'float16' if dtype_str == 'bfloat16' else dtype_str
ref_gpu_util = elementwise_data[DEVICE_NAME][N][ref_dtype_str]
max_gpu_perf = get_dram_gbps()
z = torch.empty((N, ), dtype=dtype, device='cuda')
x = torch.randn_like(z)
y = torch.randn_like(z)
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench(fn, return_mode="min", warmup=100, rep=500)
ms = triton.testing.do_bench_cudagraph(fn)
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
print_perf(ms, cur_gpu_util, ref_gpu_util)
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01)
#######################
# Flash-Attention
@@ -196,29 +155,29 @@ def test_elementwise(N, dtype_str):
flash_attention_data = {
"a100": {
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.420,
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.202,
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.355,
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.201,
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.099,
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.424,
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.379,
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.098,
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.201,
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.199,
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.087,
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.238,
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.240,
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.210,
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.061,
(4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.135,
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.211,
(4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.062,
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.052,
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.424,
(4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.262,
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.370,
(4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.254,
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.378,
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.099,
(4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.262,
(4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.254,
(4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.125,
(4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.238,
(4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.158,
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.211,
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.134,
(4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.062,
(4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.158,
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.134,
(4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.075,
}
}
@@ -230,6 +189,8 @@ flash_attention_data = {
@pytest.mark.parametrize("seq_par", [True, False])
@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [[4, 48, 4096, 64]])
def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str):
stream = torch.cuda.Stream()
torch.cuda.set_stream(stream)
is_backward = mode == 'backward'
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
@@ -250,7 +211,7 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str):
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, return_mode="min", warmup=100, rep=500)
ms = triton.testing.do_bench_cudagraph(fn)
# compute flops
flops_per_matmul = 2. * Z * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 2 * flops_per_matmul
@@ -263,4 +224,4 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str):
cur_gpu_util = cur_gpu_perf / max_gpu_perf
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str)]
print_perf(ms, cur_gpu_util, ref_gpu_util)
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01)