fa: test for mp (#14907)

Author: wozeparrot (committed by GitHub)
Date: 2026-02-22 21:47:36 -08:00
Parent: d6145736c7
Commit: 25565b2410
2 changed files with 69 additions and 15 deletions


@@ -11,7 +11,8 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo
 def _sharded_empty(shape:Tensor, ref:Tensor, axis:int|None, dtype:DTypeLike|None=None) -> Tensor:
   dtype = dtype or ref.dtype
   if not isinstance(ref.device, tuple): return Tensor.empty(*shape, dtype=dtype, device=ref.device)
-  shape = tuple(s // len(ref.device) if i == ref.uop.axis else s for i, s in enumerate(shape))
+  shard_axis = ref.uop.axis if axis is None else axis
+  shape = tuple(s // len(ref.device) if i == shard_axis else s for i, s in enumerate(shape))
   axis = ref.uop.axis if axis is None else axis
   return Tensor(Tensor.empty(*shape, dtype=dtype, device=ref.device).uop.multi(axis), dtype=dtype, device=ref.device)
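
Note on the hunk above: _sharded_empty previously split the shape along ref.uop.axis even when the caller passed an explicit axis; the new shard_axis honors the caller's choice. A minimal sketch of the per-device shape math in plain Python (the helper name local_shape is illustrative, not tinygrad API):

def local_shape(shape, num_devices, shard_axis):
  # only the sharded axis is split across devices; every other axis keeps its global size
  if shard_axis is None: return shape
  return tuple(s // num_devices if i == shard_axis else s for i, s in enumerate(shape))

# a (B, H, 1, N) intermediate sharded on the head axis (axis 1) across 2 devices
assert local_shape((2, 32, 1, 1024), 2, 1) == (2, 16, 1, 1024)
# the same shape sharded on the batch axis (axis 0) for data parallel
assert local_shape((2, 32, 1, 1024), 2, 0) == (1, 32, 1, 1024)
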
@@ -29,34 +30,40 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
   assert D == 128, "only D=128 supported"
   num_devices = len(xq.device) if isinstance(xq.device, tuple) else 1
-  B_local = B // num_devices
-  if DEBUG >= 2: print(f"Flash Attention {B=} {B_local=} {N=} {H=} {H_KV=} {D=}")
+  is_dp = xq.uop.axis == 0
+  is_mp = xq.uop.axis == 2
+  B_local = B // num_devices if is_dp else B
+  H_local = H // num_devices if is_mp else H
+  H_KV_local = H_KV // num_devices if is_mp else H_KV
+  shard_axis = 0 if is_dp else 2 if is_mp else None
+  shard_axis_t = 0 if is_dp else 1 if is_mp else None
+  if DEBUG >= 2: print(f"Flash Attention {B=} {B_local=} {N=} {H=} {H_local=} {H_KV=} {H_KV_local=} {D=} on {num_devices} devices, {'DP' if is_dp else 'MP' if is_mp else 'no sharding'}")
   single_device = xq.device[0] if isinstance(xq.device, tuple) else xq.device
   arch = Device[single_device].renderer.arch
-  attn = _sharded_empty_like(xq, axis=0)
-  l_vec = _sharded_empty((B, H, 1, N), xq, axis=0, dtype=dtypes.float32)
+  attn = _sharded_empty_like(xq, axis=shard_axis)
+  l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
   def grad(dou:UOp, _) -> tuple[None, None, UOp, UOp, UOp]:
     do = Tensor(dou, device=dou.device)
-    dq_in = _sharded_empty((B, H, N, D), xq, axis=0)
-    dq = _sharded_empty_like(xq, axis=0)
-    dk = _sharded_empty_like(xk, axis=0)
-    dv = _sharded_empty_like(xv, axis=0)
+    dq_in = _sharded_empty((B, H, N, D), xq, axis=shard_axis_t)
+    dq = _sharded_empty_like(xq, axis=shard_axis)
+    dk = _sharded_empty_like(xk, axis=shard_axis)
+    dv = _sharded_empty_like(xv, axis=shard_axis)
     # delta_vec = (do * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach()
-    delta_vec = _sharded_empty((B, H, 1, N), xq, axis=0, dtype=dtypes.float32)
-    delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[:2]
+    delta_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
+    delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:2]
-    dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[:3]
+    dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:3]
     # unshuffle dq
-    dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[0]
+    dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[0]
     return None, None, dq.uop, dk.uop, dv.uop
-  attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D), grad_fxn=grad)[:2]
+  attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D), grad_fxn=grad)[:2]
   return attn.transpose(1, 2)
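
The DP/MP bookkeeping added above boils down to: data parallel (shard axis 0) splits the batch, so only B shrinks per device; model parallel (shard axis 2) splits the heads, so H and H_KV shrink and the custom kernels are launched with the local head counts. A hedged standalone sketch (the helper local_sizes is only for illustration; in the real code these are plain locals inside flash_attention):

def local_sizes(B, H, H_KV, num_devices, uop_axis):
  is_dp, is_mp = uop_axis == 0, uop_axis == 2
  B_local = B // num_devices if is_dp else B           # DP: split the batch across devices
  H_local = H // num_devices if is_mp else H           # MP: split the query heads
  H_KV_local = H_KV // num_devices if is_mp else H_KV  # MP: split the KV heads too
  return B_local, H_local, H_KV_local

assert local_sizes(2, 32, 8, 2, 0) == (1, 32, 8)  # DP over 2 devices
assert local_sizes(2, 32, 8, 2, 2) == (2, 16, 4)  # MP over 2 devices
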


@@ -128,7 +128,7 @@ class TestFA(unittest.TestCase):
     assert_allclose(k.grad, k_ref.grad, atol=1e-5, rtol=1e-5)
     assert_allclose(v.grad, v_ref.grad, atol=1e-5, rtol=1e-5)
-  def test_fast_fa_bwd_multidevice(self):
+  def test_fast_fa_bwd_dp(self):
     Tensor.manual_seed(42)
     B, N, H, H_KV, D = 2, 1024, 32, 8, 128
@@ -175,5 +175,52 @@ class TestFA(unittest.TestCase):
     assert_allclose(v.grad, v_ref.grad, atol=1e-5, rtol=1e-5)
     assert_allclose(k.grad, k_ref.grad, atol=1e-5, rtol=1e-5)
+  def test_fast_fa_bwd_mp(self):
+    Tensor.manual_seed(42)
+    B, N, H, H_KV, D = 2, 1024, 32, 8, 128
+    GPUS = tuple(f"AMD:{i}" for i in range(B))
+    with Context(DEBUG=0):
+      base_q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      base_k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      base_v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
+      base_do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
+    with Context(DEBUG=0):
+      q = base_q.clone().requires_grad_(True).shard(GPUS, axis=2)
+      k = base_k.clone().requires_grad_(True).shard(GPUS, axis=2)
+      v = base_v.clone().requires_grad_(True).shard(GPUS, axis=2)
+      Tensor.realize(q, k, v)
+      do = base_do.clone().shard(GPUS, axis=2)
+      Tensor.realize(do)
+    q_, k_, v_ = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+    out = flash_attention(q_, k_, v_, is_causal=True)
+    out = out.float().transpose(1, 2)
+    out.backward(do)
+    Tensor.realize(q.grad, k.grad, v.grad)
+    with Context(DEBUG=0):
+      q_ref = base_q.clone().requires_grad_(True)
+      k_ref = base_k.clone().requires_grad_(True)
+      v_ref = base_v.clone().requires_grad_(True)
+      Tensor.realize(q_ref, k_ref, v_ref)
+      do_ref = base_do.clone()
+      Tensor.realize(do_ref)
+    q_ref_, k_ref_, v_ref_ = q_ref.transpose(1, 2), k_ref.transpose(1, 2), v_ref.transpose(1, 2)
+    ref = flash_attention(q_ref_, k_ref_, v_ref_, is_causal=True)
+    ref = ref.float().transpose(1, 2)
+    ref.backward(do_ref)
+    Tensor.realize(q_ref.grad, k_ref.grad, v_ref.grad)
+    assert_allclose(q.grad, q_ref.grad, atol=1e-5, rtol=1e-5)
+    assert_allclose(v.grad, v_ref.grad, atol=1e-5, rtol=1e-5)
+    assert_allclose(k.grad, k_ref.grad, atol=1e-5, rtol=1e-5)
 if __name__ == "__main__":
   unittest.main()
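
As a quick shape sanity check on the new MP test (numbers taken from the test itself; the arithmetic below is only illustration): with two GPUs and the tensors sharded on axis 2, each device holds half of the 32 query heads and half of the 8 KV heads.

B, N, H, H_KV, D = 2, 1024, 32, 8, 128
num_devices = 2  # GPUS = ("AMD:0", "AMD:1") in the test
# q has shape (B, N, H, D) and is sharded on axis 2 (the head axis)
per_device_q = (B, N, H // num_devices, D)
per_device_kv = (B, N, H_KV // num_devices, D)
assert per_device_q == (2, 1024, 16, 128)
assert per_device_kv == (2, 1024, 4, 128)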