diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh
index 211b83ae4b..f6b28dce88 100755
--- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh
+++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh
@@ -8,9 +8,11 @@ export REWRITE_STACK_LIMIT=5000000 HCQDEV_WAIT_TIMEOUT_MS=240000
 export DEBUG=${DEBUG:-2}
 export FLASH_ATTENTION=${FLASH_ATTENTION:-1}
 export ALL2ALL=${ALL2ALL:-1}
+export USE_ATOMICS=${USE_ATOMICS:-1}
+export ASM_GEMM=${ASM_GEMM:-1}
 export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
 
-export DP=8 BS=8 EVAL_BS=8 GRADIENT_ACC_STEPS=2
+export DP=8 BS=16 EVAL_BS=8 GRADIENT_ACC_STEPS=1
 export GBS=$((BS * GRADIENT_ACC_STEPS))
 
 export MODEL="llama3"
diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
index 756774b7b0..70b750af86 100755
--- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
+++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
@@ -13,7 +13,7 @@ export USE_ATOMICS=${USE_ATOMICS:-1}
 export ASM_GEMM=${ASM_GEMM:-1}
 export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
 
-export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}
+export DP=${DP:-8} BS=${BS:-16} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-1}
 export GBS=$((BS * GRADIENT_ACC_STEPS))
 
 export MODEL="llama3"
diff --git a/extra/thunder/tiny/fa.py b/extra/thunder/tiny/fa.py
index 08e8ebefb9..a3e1ce059d 100644
--- a/extra/thunder/tiny/fa.py
+++ b/extra/thunder/tiny/fa.py
@@ -2,7 +2,7 @@ import math
 
 from tinygrad import Tensor, dtypes
 from tinygrad.helpers import DEBUG
-from tinygrad.uop.ops import UOp
+from tinygrad.uop.ops import UOp, Ops
 from extra.thunder.tiny.tk import WARP_THREADS
 from extra.thunder.tiny.tk.kernel import Kernel
 
@@ -43,11 +43,12 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
   B_local = B // num_devices
   if DEBUG >= 2: print(f"Flash Attention {B=} {B_local=} {N=} {H=} {D=} {H_KV=} {GROUP_SIZE=}")
 
-  def custom_forward(ou:UOp, l_vecu:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp) -> UOp:
+  def _custom_forward_impl(ou:UOp, l_vecu:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp|None) -> UOp:
     with Kernel("fa_custom_forward", (H, N // (Q_BLOCK_SIZE*NUM_WORKERS), B_local), NUM_WORKERS * WARP_THREADS) as ker:
       warp = ker.warp
 
-      o, q, k, v, mask, l_vec = GL(ou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker), GL(l_vecu, ker)
+      o, q, k, v, l_vec = GL(ou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(l_vecu, ker)
+      mask = GL(masku, ker) if masku is not None else None
 
       head = ker.blockIdx_x
       head_kv = head // GROUP_SIZE
@@ -86,7 +87,8 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
       q_reg = warp.copy(q_reg, q_reg_fl)
       q_reg_transposed = warp.transpose(q_reg_transposed, q_reg)
 
-      for kv_idx in ker.range(N // KV_BLOCK_SIZE):
+      num_kv_blocks = (q_seq + 1) if is_causal else (N // KV_BLOCK_SIZE)
+      for kv_idx in ker.range(num_kv_blocks):
         k_smem = warp.load(k_smem, k, (), (batch, kv_idx, head_kv, 0), axis=1)
         v_smem = warp.load(v_smem, v, (), (batch, kv_idx, head_kv, 0), axis=1)
 
@@ -99,9 +101,16 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
         att_block = warp.mma_AtB(att_block, k_reg_transposed, q_reg_transposed)
 
         # apply attention mask
-        mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_seq, kv_idx), axis=2)
-        mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
-        att_block += mask_reg_transposed
+        if is_causal:
+          bs_rows, bs_cols, bs_stride = att_block.base_shape.rows, att_block.base_shape.cols, att_block.base_shape.stride
+          q_base = q_seq * Q_BLOCK_SIZE + (warp.laneid % bs_cols)
+          kv_base = kv_idx * KV_BLOCK_SIZE + (warp.laneid // bs_cols) * bs_stride
+          att_block = warp.map(att_block,
+            lambda x, idx: ((kv_base + idx[0]*bs_rows + idx[2]) > (q_base + idx[1]*bs_cols)).alu(Ops.WHERE, UOp.ufix(x._uop, -math.inf), x))
+        elif mask is not None:
+          mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_seq, kv_idx), axis=2)
+          mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
+          att_block += mask_reg_transposed
 
         # softmax
         max_vec_last = warp.copy(max_vec_last.after(kv_idx), max_vec)
@@ -141,11 +150,18 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
 
     return ker.finish()
 
-  def custom_backward_q(dqu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
+  def custom_forward_causal(ou:UOp, l_vecu:UOp, qu:UOp, ku:UOp, vu:UOp) -> UOp:
+    return _custom_forward_impl(ou, l_vecu, qu, ku, vu, None)
+
+  def custom_forward_masked(ou:UOp, l_vecu:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp) -> UOp:
+    return _custom_forward_impl(ou, l_vecu, qu, ku, vu, masku)
+
+  def _custom_backward_q_impl(dqu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp|None, l_vecu:UOp, delta_vecu:UOp) -> UOp:
     with Kernel("fa_custom_backward_q", (H, N // (Q_BLOCK_SIZE*NUM_WORKERS), B_local), NUM_WORKERS * WARP_THREADS) as ker:
       warp = ker.warp
 
-      dq, do, q, k, v, mask = GL(dqu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker)
+      dq, do, q, k, v = GL(dqu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker)
+      mask = GL(masku, ker) if masku is not None else None
       l_vec, delta_vec = GL(l_vecu, ker), GL(delta_vecu, ker)
 
       head = ker.blockIdx_x
@@ -194,7 +210,8 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
       l_vec_reg *= 1.0 / math.log(2)
       delta_vec_reg = warp.load(delta_vec_reg, delta_vec, (), (batch, head, 0, q_seq), axis=2)
 
-      for kv_idx in ker.range(N // KV_BLOCK_SIZE):
+      num_kv_blocks = (q_seq + 1) if is_causal else (N // KV_BLOCK_SIZE)
+      for kv_idx in ker.range(num_kv_blocks):
         k_smem = warp.load(k_smem, k, (), (batch, kv_idx, head_kv, 0), axis=1)
         v_smem = warp.load(v_smem, v, (), (batch, kv_idx, head_kv, 0), axis=1)
 
@@ -209,9 +226,16 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
         att_block = warp.mma_AtB(att_block, k_reg_t, q_reg_t)
 
         # apply attention mask
-        mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_seq, kv_idx), axis=2)
-        mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
-        att_block += mask_reg_transposed
+        if is_causal:
+          bs_rows, bs_cols, bs_stride = att_block.base_shape.rows, att_block.base_shape.cols, att_block.base_shape.stride
+          q_base = q_seq * Q_BLOCK_SIZE + (warp.laneid % bs_cols)
+          kv_base = kv_idx * KV_BLOCK_SIZE + (warp.laneid // bs_cols) * bs_stride
+          att_block = warp.map(att_block,
+            lambda x, idx: ((kv_base + idx[0]*bs_rows + idx[2]) > (q_base + idx[1]*bs_cols)).alu(Ops.WHERE, UOp.ufix(x._uop, -math.inf), x))
+        elif mask is not None:
+          mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_seq, kv_idx), axis=2)
+          mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
+          att_block += mask_reg_transposed
 
         att_block -= l_vec_reg
         att_block = att_block.exp2()
@@ -231,11 +255,18 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
 
     return ker.finish()
 
-  def custom_backward_kv(dku:UOp, dvu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp):
+  def custom_backward_q_causal(dqu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
+    return _custom_backward_q_impl(dqu, dou, qu, ku, vu, None, l_vecu, delta_vecu)
+
+  def custom_backward_q_masked(dqu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
+    return _custom_backward_q_impl(dqu, dou, qu, ku, vu, masku, l_vecu, delta_vecu)
+
+  def _custom_backward_kv_impl(dku:UOp, dvu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp|None, l_vecu:UOp, delta_vecu:UOp):
     with Kernel("fa_custom_backward_kv", (H_KV, N // (KV_BLOCK_SIZE*NUM_WORKERS), B_local), NUM_WORKERS * WARP_THREADS) as ker:
       warp = ker.warp
 
-      dk, dv, do, q, k, v, mask = GL(dku, ker), GL(dvu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker)
+      dk, dv, do, q, k, v = GL(dku, ker), GL(dvu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker)
+      mask = GL(masku, ker) if masku is not None else None
       l_vec, delta_vec = GL(l_vecu, ker), GL(delta_vecu, ker)
 
       head_kv = ker.blockIdx_x
@@ -302,9 +333,16 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
         att_block *= (1.0 / math.sqrt(D)) * (1.0 / math.log(2))
 
         # apply attention mask
-        mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_idx, kv_seq), axis=2)
-        mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
-        att_block += mask_reg_transposed
+        if is_causal:
+          bs_rows, bs_cols, bs_stride = att_block.base_shape.rows, att_block.base_shape.cols, att_block.base_shape.stride
+          q_base = q_idx * Q_BLOCK_SIZE + (warp.laneid % bs_cols)
+          kv_base = kv_seq * KV_BLOCK_SIZE + (warp.laneid // bs_cols) * bs_stride
+          att_block = warp.map(att_block,
+            lambda x, idx: ((kv_base + idx[0]*bs_rows + idx[2]) > (q_base + idx[1]*bs_cols)).alu(Ops.WHERE, UOp.ufix(x._uop, -math.inf), x))
+        elif mask is not None:
+          mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_idx, kv_seq), axis=2)
+          mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
+          att_block += mask_reg_transposed
 
         att_block -= l_vec_reg
         att_block = att_block.exp2()
@@ -336,24 +374,31 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
 
     return ker.finish(2)
 
+  def custom_backward_kv_causal(dku:UOp, dvu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, l_vecu:UOp, delta_vecu:UOp):
+    return _custom_backward_kv_impl(dku, dvu, dou, qu, ku, vu, None, l_vecu, delta_vecu)
+
+  def custom_backward_kv_masked(dku:UOp, dvu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp):
+    return _custom_backward_kv_impl(dku, dvu, dou, qu, ku, vu, masku, l_vecu, delta_vecu)
+
   single_device = xq.device[0] if isinstance(xq.device, tuple) else xq.device
   if is_causal:
     if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
-    attn_mask = Tensor.ones((B, 1, N, N), requires_grad=False, device=single_device, dtype=dtypes.bool).tril()
-  if attn_mask is not None:
+  elif attn_mask is not None:
     if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
+    if attn_mask.shape != (B, 1, N, N):
+      attn_mask = attn_mask.expand(B, 1, N, N)
+    if isinstance(xq.device, tuple) and not isinstance(attn_mask.device, tuple):
+      attn_mask = attn_mask.shard(xq.device, axis=0)
   else:
     attn_mask = Tensor.zeros((B, 1, N, N), requires_grad=False, device=single_device, dtype=dtypes.float32)
-  if attn_mask.shape != (B, 1, N, N):
-    attn_mask = attn_mask.expand(B, 1, N, N)
-  if isinstance(xq.device, tuple) and not isinstance(attn_mask.device, tuple):
-    attn_mask = attn_mask.shard(xq.device, axis=0)
+    if isinstance(xq.device, tuple):
+      attn_mask = attn_mask.shard(xq.device, axis=0)
 
   attn = _sharded_empty_like(xq, axis=0)
   l_vec = _sharded_empty((B, H, 1, N), xq, axis=0)
 
-  def grad(gradu:UOp, _) -> tuple[None, None, UOp, UOp, UOp, None]:
+  def grad_causal(gradu:UOp, _) -> tuple[None, None, UOp, UOp, UOp]:
     grad = Tensor(gradu, device=gradu.device)
     grad_q = _sharded_empty_like(xq, axis=0)
     grad_k = _sharded_empty_like(xk, axis=0)
@@ -361,11 +406,26 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
 
     delta_vec = (grad * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach()
 
-    grad_q = Tensor.custom_kernel(grad_q, grad, xq, xk, xv, attn_mask, l_vec, delta_vec, fxn=custom_backward_q)[0]
-    grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, grad, xq, xk, xv, attn_mask, l_vec, delta_vec, fxn=custom_backward_kv)[:2]
+    grad_q = Tensor.custom_kernel(grad_q, grad, xq, xk, xv, l_vec, delta_vec, fxn=custom_backward_q_causal)[0]
+    grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, grad, xq, xk, xv, l_vec, delta_vec, fxn=custom_backward_kv_causal)[:2]
+    return (None, None, grad_q.uop, grad_k.uop, grad_v.uop)
+
+  def grad_masked(gradu:UOp, _) -> tuple[None, None, UOp, UOp, UOp, None]:
+    grad = Tensor(gradu, device=gradu.device)
+    grad_q = _sharded_empty_like(xq, axis=0)
+    grad_k = _sharded_empty_like(xk, axis=0)
+    grad_v = _sharded_empty_like(xv, axis=0)
+
+    delta_vec = (grad * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach()
+
+    grad_q = Tensor.custom_kernel(grad_q, grad, xq, xk, xv, attn_mask, l_vec, delta_vec, fxn=custom_backward_q_masked)[0]
+    grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, grad, xq, xk, xv, attn_mask, l_vec, delta_vec, fxn=custom_backward_kv_masked)[:2]
     return (None, None, grad_q.uop, grad_k.uop, grad_v.uop, None)
 
-  attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, attn_mask, fxn=custom_forward, grad_fxn=grad)[:2]
+  if is_causal:
+    attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=custom_forward_causal, grad_fxn=grad_causal)[:2]
+  else:
+    attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, attn_mask, fxn=custom_forward_masked, grad_fxn=grad_masked)[:2]
 
   attn_ = attn[:, :N_, :, :D_]
   return attn_.transpose(1, 2).cast(odtype)
diff --git a/test/testextra/test_tk.py b/test/testextra/test_tk.py
index 362f97c7c3..2fae9efdbe 100644
--- a/test/testextra/test_tk.py
+++ b/test/testextra/test_tk.py
@@ -750,6 +750,35 @@ class TestTK(unittest.TestCase):
 
     fa_jitted = TinyJit(flash_attention)
 
+    for _ in range(10):
+      st = time.perf_counter()
+      out = fa_jitted(q, k, v, is_causal=False)
+      et = time.perf_counter() - st
+      attn_flops = 2 * B * H * N * N * D + \
+                   4 * B * H * N * N + \
+                   2 * B * H * N * N * D
+      print(f"{attn_flops/(et*1e9):2f} GFLOPS")
+    out = out.float().transpose(1, 2)
+
+    ref = q.scaled_dot_product_attention(k, v, is_causal=False, enable_gqa=True).float().transpose(1, 2)
+
+    np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=2e-2, rtol=2e-2)
+
+  def test_fast_fa_causal(self):
+    from extra.thunder.tiny.fa import flash_attention
+
+    B, N, H, H_KV, D = 2, 8192, 32, 8, 128
+
+    with Context(DEBUG=0):
+      q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16).contiguous()
+      k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16).contiguous()
+      v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16).contiguous()
+      Tensor.realize(q, k, v)
+
+    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+    fa_jitted = TinyJit(flash_attention)
+
     for _ in range(10):
       st = time.perf_counter()
       out = fa_jitted(q, k, v, is_causal=True)
@@ -838,7 +867,7 @@ class TestTK(unittest.TestCase):
 
     np.testing.assert_allclose(q.grad.numpy(), q_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
     np.testing.assert_allclose(v.grad.numpy(), v_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
-    np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
+    np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=6e-2, rtol=2e-2)
 
   def test_fast_fa_bwd_causal_jitted(self):
     from extra.thunder.tiny.fa import flash_attention
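
Note on the is_causal fast path (not part of the patch): instead of materializing the (B, 1, N, N) additive mask, the kernels now evaluate a per-element predicate inside each kv block — positions where the key index exceeds the query index are set to -inf via Ops.WHERE — and the kv loop stops at the diagonal (num_kv_blocks = q_seq + 1), so blocks strictly above the diagonal are never visited. The sketch below is a minimal single-head NumPy reference of that blockwise causal semantics, not the kernel itself; causal_attention_reference, q_block, and kv_block are illustrative stand-ins for Q_BLOCK_SIZE/KV_BLOCK_SIZE, and it assumes N is divisible by both block sizes.

import numpy as np

def causal_attention_reference(q, k, v, q_block=64, kv_block=64):
  # q, k, v: (N, D) for a single head; returns (N, D) in float32
  N, D = q.shape
  out = np.zeros((N, D), dtype=np.float32)
  scale = 1.0 / np.sqrt(D)
  for q0 in range(0, N, q_block):
    rows = slice(q0, q0 + q_block)
    scores = np.full((q_block, N), -np.inf, dtype=np.float32)
    # mirror num_kv_blocks = q_seq + 1: only kv blocks at or below the diagonal are visited
    for kv0 in range(0, q0 + q_block, kv_block):
      cols = slice(kv0, kv0 + kv_block)
      s = (q[rows].astype(np.float32) @ k[cols].astype(np.float32).T) * scale
      # the kernel's predicate: entries with kv position > q position become -inf
      q_pos = np.arange(q0, q0 + q_block)[:, None]
      kv_pos = np.arange(kv0, kv0 + kv_block)[None, :]
      scores[:, cols] = np.where(kv_pos > q_pos, -np.inf, s)
    # untouched blocks stay at -inf and contribute zero weight after softmax
    p = np.exp(scores - scores.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    out[rows] = p @ v.astype(np.float32)
  return out

The added test_fast_fa_causal exercises the same property at the tensor level by comparing the kernel output against scaled_dot_product_attention(..., is_causal=True, enable_gqa=True).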