Enable backward pass in FA tutorial test (#282)
Enabled the backward pass in the fused attention tutorial test. The tolerance used when comparing against the naive implementation had to be loosened. The block size is forced to 64x64 because of the 64 KiB LDS on the MI200 series; the default of 128 targets the A100's larger SMEM. The smaller blocks change the order of computation and result in a larger gap between the naive and fused attention (FA) implementations.
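A minimal sketch of the reasoning above, assuming a hypothetical pick_block_size() helper that is not part of the tutorial: the attention tile size is chosen per backend because the MI200's 64 KiB LDS cannot hold the 128-wide tiles the default A100 configuration assumes.

import torch

# Hypothetical helper, not the tutorial's actual code: choose the attention tile
# size based on the backend's on-chip scratch memory, mirroring the commit message.
def pick_block_size():
    if torch.version.hip is not None:
        # ROCm / MI200: 64 KiB of LDS per workgroup forces a 64x64 tile.
        return 64, 64
    # CUDA / A100: the larger shared memory budget allows the default 128-wide tile.
    return 128, 128

BLOCK_M, BLOCK_N = pick_block_size()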
@@ -292,28 +292,28 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     p = torch.softmax(p.float(), dim=-1).half()
     # p = torch.exp(p)
     ref_out = torch.matmul(p, v)
 
-    if torch.version.hip is None:
-        ref_out.backward(dout)
-        ref_dv, v.grad = v.grad.clone(), None
-        ref_dk, k.grad = k.grad.clone(), None
-        ref_dq, q.grad = q.grad.clone(), None
+    ref_out.backward(dout)
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dq, q.grad = q.grad.clone(), None
     # # triton implementation
     tri_out = attention(q, k, v, sm_scale)
     # print(ref_out)
     # print(tri_out)
-    if torch.version.hip is None:
-        tri_out.backward(dout)
-        tri_dv, v.grad = v.grad.clone(), None
-        tri_dk, k.grad = k.grad.clone(), None
-        tri_dq, q.grad = q.grad.clone(), None
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
     # compare
     assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
     if torch.version.hip is None:
-        # TODO: Enable backward pass for MFMA dot.
         assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0)
         assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
         assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)
+    # The current block size for MI200 series is 64x64. This results in
+    # larger differences in float results due to rounding.
+    else:
+        assert torch.allclose(ref_dv, tri_dv, atol=1e-1, rtol=0)
+        assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
+        assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)
 
 
 try:
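In the checks above, rtol=0 reduces torch.allclose to a pure absolute-error comparison, so the looser atol=1e-1 on ROCm tolerates up to 0.1 of absolute difference in dV, absorbing the extra rounding introduced by the 64x64 block schedule. A small illustration with made-up values:

import torch

a = torch.tensor([1.00, 2.00])
b = torch.tensor([1.05, 2.08])
# With rtol=0 the check is simply max(|a - b|) <= atol.
print(torch.allclose(a, b, atol=1e-1, rtol=0))  # True: largest difference is 0.08
print(torch.allclose(a, b, atol=1e-2, rtol=0))  # False under the tighter CUDA bound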