[TEST] Added flash attention tests for D_HEAD in {16, 32, 128}. (#1709)

This commit is contained in:
Philippe Tillet
2023-05-27 22:48:22 -07:00
committed by GitHub
parent f29838a3ea
commit 420e4acecc
3 changed files with 5 additions and 7 deletions

View File

@@ -55,10 +55,6 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
bool changed = false;
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
changed = false;
if (ret[0] * ret[1] >= numWarps)

View File

@@ -5,7 +5,10 @@ import triton
import triton.ops
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16),
(4, 48, 1024, 32),
(4, 48, 1024, 64),
(4, 48, 1024, 128)])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
def test_op(Z, H, N_CTX, D_HEAD, dtype):
capability = torch.cuda.get_device_capability()

View File

@@ -203,8 +203,7 @@ class _attention(torch.autograd.Function):
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
# assert Lk in {16, 32, 64, 128}
assert Lk in {64} # TODO: fix other cases
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)