mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fix for FP8 QK inputs in flash attention forward pass (#2435)
This commit is contained in:
@@ -434,7 +434,7 @@ static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
|
||||
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
|
||||
return src && dst && src.getVersionMajor() == 3 &&
|
||||
src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
|
||||
- dst.getWarpsPerCTA()[1] == 1 && srcInstrShape[2] == dstInstrShape[2];
|
||||
+ dst.getWarpsPerCTA()[1] == 1;
|
||||
}
|
||||
|
||||
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
|
||||
@@ -596,6 +596,9 @@ def bench_flash_attention(
|
||||
if provider == "triton":
|
||||
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
+ if mode == "fwd":
|
||||
+ q = q.to(torch.float8_e5m2)
|
||||
+ k = k.to(torch.float8_e5m2)
|
||||
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
sm_scale = 1.3
|
||||
fn = lambda: attention(q, k, v, causal, sm_scale)
|
||||
|
||||
Reference in New Issue
Block a user