mirror of https://github.com/ROCm/ROCm.git
[TEST] Added flash attention tests for D_HEAD in {16, 32, 128}. (#1709)
@@ -55,10 +55,6 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
   SmallVector<unsigned, 2> ret = {1, 1};
   SmallVector<int64_t, 2> shapePerWarp = {16, 8};
   bool changed = false;
-  // TODO (@daadaada): double-check.
-  // original logic in
-  // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
-  // seems buggy for shape = [32, 16] ?
   do {
     changed = false;
     if (ret[0] * ret[1] >= numWarps)
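Not part of the diff: a rough Python sketch of the kind of warp-splitting heuristic the do/while loop above implements. Only the first check of the loop body appears in this hunk, so the doubling rule below (grow the warp grid along the dimension with the most remaining work per warp) is an assumption for illustration, and warps_per_tile is a hypothetical helper, not a function in the repository.

    # Sketch only: keep doubling the warp count along the dimension that still
    # has the most tiles of work per warp until num_warps warps are assigned.
    def warps_per_tile(shape, num_warps, shape_per_warp=(16, 8)):
        ret = [1, 1]
        while ret[0] * ret[1] < num_warps:
            # Remaining work per warp along each dimension of the tile.
            work0 = shape[0] / (shape_per_warp[0] * ret[0])
            work1 = shape[1] / (shape_per_warp[1] * ret[1])
            if work0 >= work1:
                ret[0] *= 2
            else:
                ret[1] *= 2
        return ret

    print(warps_per_tile((128, 128), num_warps=4))  # [2, 2]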
@@ -5,7 +5,10 @@ import triton
 import triton.ops
 
 
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16),
+                                                 (4, 48, 1024, 32),
+                                                 (4, 48, 1024, 64),
+                                                 (4, 48, 1024, 128)])
 @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
 def test_op(Z, H, N_CTX, D_HEAD, dtype):
     capability = torch.cuda.get_device_capability()
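Not part of the diff: pytest takes the cartesian product of the two parametrize decorators, so the updated test covers 4 head dimensions x 2 dtypes = 8 configurations. Below is a minimal sketch of how inputs with the (Z, H, N_CTX, D_HEAD) layout used by those decorators could be built; make_qkv, the normal_ initialization, and the 1/sqrt(D_HEAD) softmax scale are illustrative assumptions, not code copied from the test file.

    import torch

    def make_qkv(Z, H, N_CTX, D_HEAD, dtype, device="cuda"):
        # Query, key, and value tensors in the (batch, heads, seq_len, head_dim) layout.
        shape = (Z, H, N_CTX, D_HEAD)
        q = torch.empty(shape, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
        k = torch.empty(shape, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
        v = torch.empty(shape, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
        sm_scale = 1.0 / (D_HEAD ** 0.5)  # assumed softmax scaling factor
        return q, k, v, sm_scale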
@@ -203,8 +203,7 @@ class _attention(torch.autograd.Function):
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
-        # assert Lk in {16, 32, 64, 128}
-        assert Lk in {64} # TODO: fix other cases
+        assert Lk in {16, 32, 64, 128}
         o = torch.empty_like(q)
         grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
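Not part of the diff: a small worked example of the launch-grid arithmetic on the grid = (...) line above. Axis 0 gets one program per BLOCK query rows and axis 1 gets one program per (batch, head) pair. BLOCK = 128 is assumed for illustration; the real value comes from the surrounding forward() code, which is not shown here.

    import math

    def attention_grid(Z, H, N_CTX, BLOCK=128):
        # (programs over query rows, programs over batch*heads, 1)
        return (math.ceil(N_CTX / BLOCK), Z * H, 1)

    print(attention_grid(4, 48, 1024))  # (8, 192, 1) for the test's largest shape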