From 25565b2410e738ca57f09691a98e6ca38ee2acbf Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Sun, 22 Feb 2026 21:47:36 -0800 Subject: [PATCH] fa: test for mp (#14907) --- extra/thunder/amd/fa.py | 35 +++++++++++++++----------- test/testextra/test_hk_fa.py | 49 +++++++++++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 15 deletions(-) diff --git a/extra/thunder/amd/fa.py b/extra/thunder/amd/fa.py index 698f6d1b7a..e9501d9a82 100644 --- a/extra/thunder/amd/fa.py +++ b/extra/thunder/amd/fa.py @@ -11,7 +11,8 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo def _sharded_empty(shape:Tensor, ref:Tensor, axis:int|None, dtype:DTypeLike|None=None) -> Tensor: dtype = dtype or ref.dtype if not isinstance(ref.device, tuple): return Tensor.empty(*shape, dtype=dtype, device=ref.device) - shape = tuple(s // len(ref.device) if i == ref.uop.axis else s for i, s in enumerate(shape)) + shard_axis = ref.uop.axis if axis is None else axis + shape = tuple(s // len(ref.device) if i == shard_axis else s for i, s in enumerate(shape)) axis = ref.uop.axis if axis is None else axis return Tensor(Tensor.empty(*shape, dtype=dtype, device=ref.device).uop.multi(axis), dtype=dtype, device=ref.device) @@ -29,34 +30,40 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False assert D == 128, "only D=128 supported" num_devices = len(xq.device) if isinstance(xq.device, tuple) else 1 - B_local = B // num_devices - if DEBUG >= 2: print(f"Flash Attention {B=} {B_local=} {N=} {H=} {H_KV=} {D=}") + is_dp = xq.uop.axis == 0 + is_mp = xq.uop.axis == 2 + B_local = B // num_devices if is_dp else B + H_local = H // num_devices if is_mp else H + H_KV_local = H_KV // num_devices if is_mp else H_KV + shard_axis = 0 if is_dp else 2 if is_mp else None + shard_axis_t = 0 if is_dp else 1 if is_mp else None + if DEBUG >= 2: print(f"Flash Attention {B=} {B_local=} {N=} {H=} {H_local=} {H_KV=} {H_KV_local=} {D=} on {num_devices} devices, {'DP' if is_dp else 'MP' if is_mp else 'no sharding'}") single_device = xq.device[0] if isinstance(xq.device, tuple) else xq.device arch = Device[single_device].renderer.arch - attn = _sharded_empty_like(xq, axis=0) - l_vec = _sharded_empty((B, H, 1, N), xq, axis=0, dtype=dtypes.float32) + attn = _sharded_empty_like(xq, axis=shard_axis) + l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t) def grad(dou:UOp, _) -> tuple[None, None, UOp, UOp, UOp]: do = Tensor(dou, device=dou.device) - dq_in = _sharded_empty((B, H, N, D), xq, axis=0) - dq = _sharded_empty_like(xq, axis=0) - dk = _sharded_empty_like(xk, axis=0) - dv = _sharded_empty_like(xv, axis=0) + dq_in = _sharded_empty((B, H, N, D), xq, axis=shard_axis_t) + dq = _sharded_empty_like(xq, axis=shard_axis) + dk = _sharded_empty_like(xk, axis=shard_axis) + dv = _sharded_empty_like(xv, axis=shard_axis) # delta_vec = (do * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach() - delta_vec = _sharded_empty((B, H, 1, N), xq, axis=0, dtype=dtypes.float32) - delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[:2] + delta_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t) + delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:2] - dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[:3] + dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:3] # unshuffle dq - dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[0] + dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[0] return None, None, dq.uop, dk.uop, dv.uop - attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D), grad_fxn=grad)[:2] + attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D), grad_fxn=grad)[:2] return attn.transpose(1, 2) diff --git a/test/testextra/test_hk_fa.py b/test/testextra/test_hk_fa.py index a3c4c87a64..c472145157 100644 --- a/test/testextra/test_hk_fa.py +++ b/test/testextra/test_hk_fa.py @@ -128,7 +128,7 @@ class TestFA(unittest.TestCase): assert_allclose(k.grad, k_ref.grad, atol=1e-5, rtol=1e-5) assert_allclose(v.grad, v_ref.grad, atol=1e-5, rtol=1e-5) - def test_fast_fa_bwd_multidevice(self): + def test_fast_fa_bwd_dp(self): Tensor.manual_seed(42) B, N, H, H_KV, D = 2, 1024, 32, 8, 128 @@ -175,5 +175,52 @@ class TestFA(unittest.TestCase): assert_allclose(v.grad, v_ref.grad, atol=1e-5, rtol=1e-5) assert_allclose(k.grad, k_ref.grad, atol=1e-5, rtol=1e-5) + def test_fast_fa_bwd_mp(self): + Tensor.manual_seed(42) + + B, N, H, H_KV, D = 2, 1024, 32, 8, 128 + GPUS = tuple(f"AMD:{i}" for i in range(B)) + + with Context(DEBUG=0): + base_q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous() + base_k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous() + base_v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous() + + base_do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous() + + with Context(DEBUG=0): + q = base_q.clone().requires_grad_(True).shard(GPUS, axis=2) + k = base_k.clone().requires_grad_(True).shard(GPUS, axis=2) + v = base_v.clone().requires_grad_(True).shard(GPUS, axis=2) + Tensor.realize(q, k, v) + + do = base_do.clone().shard(GPUS, axis=2) + Tensor.realize(do) + + q_, k_, v_ = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + out = flash_attention(q_, k_, v_, is_causal=True) + out = out.float().transpose(1, 2) + out.backward(do) + Tensor.realize(q.grad, k.grad, v.grad) + + with Context(DEBUG=0): + q_ref = base_q.clone().requires_grad_(True) + k_ref = base_k.clone().requires_grad_(True) + v_ref = base_v.clone().requires_grad_(True) + Tensor.realize(q_ref, k_ref, v_ref) + + do_ref = base_do.clone() + Tensor.realize(do_ref) + + q_ref_, k_ref_, v_ref_ = q_ref.transpose(1, 2), k_ref.transpose(1, 2), v_ref.transpose(1, 2) + ref = flash_attention(q_ref_, k_ref_, v_ref_, is_causal=True) + ref = ref.float().transpose(1, 2) + ref.backward(do_ref) + Tensor.realize(q_ref.grad, k_ref.grad, v_ref.grad) + + assert_allclose(q.grad, q_ref.grad, atol=1e-5, rtol=1e-5) + assert_allclose(v.grad, v_ref.grad, atol=1e-5, rtol=1e-5) + assert_allclose(k.grad, k_ref.grad, atol=1e-5, rtol=1e-5) + if __name__ == "__main__": unittest.main()