[TESTS] Added attention regression tests (#1227)

This commit is contained in:
Philippe Tillet
2023-02-21 20:22:36 -08:00
committed by GitHub
parent 0192ab2178
commit 174f121c1c
2 changed files with 51 additions and 4 deletions

View File

@@ -57,7 +57,7 @@ matmul_data = {
(512, 512, 512): {'float16': 0.08, 'float32': 0.13, 'int8': 0.05},
(1024, 1024, 1024): {'float16': 0.33, 'float32': 0.35, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.64, 'float32': 0.57, 'int8': 0.34},
(4096, 4096, 4096): {'float16': 0.82, 'float32': 0.75, 'int8': 0.46},
(4096, 4096, 4096): {'float16': 0.80, 'float32': 0.75, 'int8': 0.46},
(8192, 8192, 8192): {'float16': 0.77, 'float32': 0.85, 'int8': 0.51},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
@@ -153,3 +153,51 @@ def test_elementwise(N):
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
#######################
# Flash-Attention
#######################
flash_attention_data = {
"a100": {
(4, 48, 4096, 64, 'forward', 'float16'): 0.37,
(4, 48, 4096, 64, 'backward', 'float16'): 0.25,
}
}
@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [[4, 48, 4096, 64]])
@pytest.mark.parametrize("mode", ['forward', 'backward'])
@pytest.mark.parametrize("dtype_str", ['float16'])
def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
is_backward = mode == 'backward'
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability < 80")
torch.manual_seed(20)
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
# init data
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
# benchmark
fn = lambda: triton.ops.attention(q, k, v, sm_scale)
if is_backward:
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
# compute flops
flops_per_matmul = 2. * Z * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 2 * flops_per_matmul
if is_backward:
total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute)
cur_gpu_perf = total_flops / ms * 1e-9
# maximum flops
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
cur_gpu_util = cur_gpu_perf / max_gpu_perf
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)