mirror of https://github.com/ROCm/ROCm.git
[TESTS] Added attention regression tests (#1227)
.github/workflows/integration-tests.yml (vendored): 5 lines changed
@@ -3,10 +3,9 @@ name: Integration Tests
 on:
   workflow_dispatch:
   pull_request:
-    branches:
-      - main
-      - triton-mlir
+    branches: [main]
   merge_group:
     branches: [main]
+    types: [checks_requested]
 
 concurrency:
@@ -57,7 +57,7 @@ matmul_data = {
     (512, 512, 512): {'float16': 0.08, 'float32': 0.13, 'int8': 0.05},
     (1024, 1024, 1024): {'float16': 0.33, 'float32': 0.35, 'int8': 0.169},
     (2048, 2048, 2048): {'float16': 0.64, 'float32': 0.57, 'int8': 0.34},
-    (4096, 4096, 4096): {'float16': 0.82, 'float32': 0.75, 'int8': 0.46},
+    (4096, 4096, 4096): {'float16': 0.80, 'float32': 0.75, 'int8': 0.46},
     (8192, 8192, 8192): {'float16': 0.77, 'float32': 0.85, 'int8': 0.51},
     # tall-skinny
     (16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
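For context (not part of the diff): each matmul_data entry reads as the fraction of peak tensor-core throughput the benchmarked kernel is expected to reach for that (M, N, K) shape and dtype, and the regression tests compare the measured fraction against it with assert_almost_equal to two decimals, as in the tests below. A minimal sketch of what the updated 0.80 fp16 reference for the 4096 case implies, assuming a nominal 312 TFLOP/s A100 fp16 tensor-core peak (an assumption for illustration only; the actual tests derive the peak from the live SM clock):

# Illustration only, not part of the diff.
M = N = K = 4096
ref_util = 0.80                  # stored reference fraction for float16
peak_tflops = 312.0              # assumed A100 dense fp16 tensor-core peak
flops = 2.0 * M * N * K          # one multiply-add counted as 2 flops
expected_ms = flops / (ref_util * peak_tflops * 1e12) * 1e3
print(f"expected runtime ~= {expected_ms:.3f} ms")   # roughly 0.55 ms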
@@ -153,3 +153,51 @@ def test_elementwise(N):
     cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
     cur_gpu_util = cur_gpu_perf / max_gpu_perf
     triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
+
+#######################
+# Flash-Attention
+#######################
+
+
+flash_attention_data = {
+    "a100": {
+        (4, 48, 4096, 64, 'forward', 'float16'): 0.37,
+        (4, 48, 4096, 64, 'backward', 'float16'): 0.25,
+    }
+}
+
+
+@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [[4, 48, 4096, 64]])
+@pytest.mark.parametrize("mode", ['forward', 'backward'])
+@pytest.mark.parametrize("dtype_str", ['float16'])
+def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
+    is_backward = mode == 'backward'
+    capability = torch.cuda.get_device_capability()
+    if capability[0] < 8:
+        pytest.skip("Flash attention only supported for compute capability >= 80")
+    torch.manual_seed(20)
+    dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
+    # init data
+    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
+    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
+    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
+    sm_scale = 0.2
+    # benchmark
+    fn = lambda: triton.ops.attention(q, k, v, sm_scale)
+    if is_backward:
+        o = fn()
+        do = torch.randn_like(o)
+        fn = lambda: o.backward(do, retain_graph=True)
+    ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
+    # compute flops
+    flops_per_matmul = 2. * Z * H * N_CTX * N_CTX * D_HEAD * 0.5
+    total_flops = 2 * flops_per_matmul
+    if is_backward:
+        total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
+    cur_gpu_perf = total_flops / ms * 1e-9
+    # maximum flops
+    cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
+    max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
+    cur_gpu_util = cur_gpu_perf / max_gpu_perf
+    ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
+    triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
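For context (not part of the diff): the flop count above charges two matmuls per attention call, Q @ K^T and P @ V, each 2 * Z * H * N_CTX * N_CTX * D_HEAD flops, with the 0.5 factor accounting for the causal mask that skips half of each matmul; the backward pass is charged 2.5x (2.0 for its own matmuls plus 0.5 for recomputing the forward). A worked example for the parametrized shape, using a hypothetical 7.2 ms forward time and an assumed 312 TFLOP/s fp16 tensor-core peak (both illustration-only numbers), which lands on the stored a100 reference of 0.37:

# Illustration only, not part of the diff.
Z, H, N_CTX, D_HEAD = 4, 48, 4096, 64
flops_per_matmul = 2.0 * Z * H * N_CTX * N_CTX * D_HEAD * 0.5   # ~4.12e11; 0.5 = causal mask
total_flops = 2 * flops_per_matmul                              # ~8.25e11 per forward pass
backward_flops = total_flops * 2.5                              # ~2.06e12 per backward pass
ms = 7.2                                                        # hypothetical measured runtime
peak_tflops = 312.0                                             # assumed A100 fp16 tensor-core peak
cur_gpu_perf = total_flops / ms * 1e-9                          # ~114.5 TFLOP/s achieved
cur_gpu_util = cur_gpu_perf / peak_tflops                       # ~0.37, the stored reference value
print(round(cur_gpu_util, 2))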