[BACKEND] Fix for FP8 QK inputs in flash attention forward pass (#2435)

This commit is contained in:
Thomas Raoux
2023-10-03 21:02:13 -07:00
committed by GitHub
parent 0d84a7d70c
commit c656a139d3
2 changed files with 4 additions and 1 deletions

View File

@@ -434,7 +434,7 @@ static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
return src && dst && src.getVersionMajor() == 3 &&
src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
dst.getWarpsPerCTA()[1] == 1 && srcInstrShape[2] == dstInstrShape[2];
dst.getWarpsPerCTA()[1] == 1;
}
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {

View File

@@ -596,6 +596,9 @@ def bench_flash_attention(
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
if mode == "fwd":
q = q.to(torch.float8_e5m2)
k = k.to(torch.float8_e5m2)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
sm_scale = 1.3
fn = lambda: attention(q, k, v, causal, sm_scale)