Enable backward pass in FA tutorial test (#282)

Enabled the backward pass in the fused attention tutorial.
The tolerance when comparing to the naive implementation
had to be changed. The block size is forced to be 64x64
due to the 64 KiB LDS. Default is block 128 for A100's
larger SMEM. This creates differences in order of computation
and reuslts in a larger gap between the naive and FA
implementations.
This commit is contained in:
Vinayak Gokhale
2023-08-03 10:12:46 -05:00
committed by GitHub
parent 31cfda8f0e
commit f1063bb33c

View File

@@ -292,28 +292,28 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
p = torch.softmax(p.float(), dim=-1).half()
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
if torch.version.hip is None:
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# # triton implementation
tri_out = attention(q, k, v, sm_scale)
# print(ref_out)
# print(tri_out)
if torch.version.hip is None:
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
if torch.version.hip is None:
# TODO: Enable backward pass for MFMA dot.
assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0)
assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)
# The current block size for MI200 series is 64x64. This results in
# larger differences in float results due to rounding.
else:
assert torch.allclose(ref_dv, tri_dv, atol=1e-1, rtol=0)
assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)
try: